updated per PR reviews
This commit is contained in:
@@ -75,7 +75,7 @@ Catch new features, interviews, and how-tos on our [Pipecat TV](https://www.yout
|
||||
| ------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/stt/elevenlabs), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Gradium](https://docs.pipecat.ai/server/services/stt/gradium), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [NVIDIA Riva](https://docs.pipecat.ai/server/services/stt/riva), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova), [Sarvam](https://docs.pipecat.ai/server/services/stt/sarvam), [Soniox](https://docs.pipecat.ai/server/services/stt/soniox), [Speechmatics](https://docs.pipecat.ai/server/services/stt/speechmatics), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [Mistral](https://docs.pipecat.ai/server/services/llm/mistral), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Gradium](https://docs.pipecat.ai/server/services/tts/gradium), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Text-to-Speech | [Async](https://docs.pipecat.ai/server/services/tts/asyncai), [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Camb AI](https://docs.pipecat.ai/server/services/tts/camb), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [Gradium](https://docs.pipecat.ai/server/services/tts/gradium), [Groq](https://docs.pipecat.ai/server/services/tts/groq), [Hume](https://docs.pipecat.ai/server/services/tts/hume), [Inworld](https://docs.pipecat.ai/server/services/tts/inworld), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [NVIDIA Riva](https://docs.pipecat.ai/server/services/tts/riva), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [Speechmatics](https://docs.pipecat.ai/server/services/tts/speechmatics), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [Grok Voice Agent](https://docs.pipecat.ai/server/services/s2s/grok), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai), [Ultravox](https://docs.pipecat.ai/server/services/s2s/ultravox) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
| Serializers | [Exotel](https://docs.pipecat.ai/server/utilities/serializers/exotel), [Plivo](https://docs.pipecat.ai/server/utilities/serializers/plivo), [Twilio](https://docs.pipecat.ai/server/utilities/serializers/twilio), [Telnyx](https://docs.pipecat.ai/server/utilities/serializers/telnyx), [Vonage](https://docs.pipecat.ai/server/utilities/serializers/vonage) |
|
||||
|
||||
@@ -1 +1 @@
|
||||
- Added Camb.ai TTS integration with MARS models (mars-flash, mars-pro, mars-instruct) for high-quality text-to-speech synthesis.
|
||||
- Added `CambTTSService`, using Camb.ai's TTS integration with MARS models (mars-flash, mars-pro, mars-instruct) for high-quality text-to-speech synthesis.
|
||||
|
||||
@@ -31,6 +31,9 @@ AZURE_DALLE_API_KEY=...
|
||||
AZURE_DALLE_ENDPOINT=https://...
|
||||
AZURE_DALLE_MODEL=...
|
||||
|
||||
# Camb.ai
|
||||
CAMB_API_KEY=...
|
||||
|
||||
# Cartesia
|
||||
CARTESIA_API_KEY=...
|
||||
CARTESIA_VOICE_ID=...
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Camb.ai TTS example with local audio (microphone/speakers).
|
||||
|
||||
This is a standalone local example for quick testing without WebRTC/Daily.
|
||||
For production use with Daily/Twilio/WebRTC, see 07zb-interruptible-camb.py
|
||||
|
||||
Requirements:
|
||||
- CAMB_API_KEY environment variable
|
||||
- OPENAI_API_KEY environment variable (for LLM)
|
||||
- DEEPGRAM_API_KEY environment variable (for STT)
|
||||
|
||||
Usage:
|
||||
python 07zb-interruptible-camb-local.py
|
||||
python 07zb-interruptible-camb-local.py --voice-id 147320
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
Frame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMRunFrame,
|
||||
TTSStartedFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import TTFBMetricsData
|
||||
from pipecat.observers.loggers.metrics_log_observer import MetricsLogObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.camb.tts import CambTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.local.audio import LocalAudioTransport, LocalAudioTransportParams
|
||||
|
||||
|
||||
class LatencyTracker(FrameProcessor):
    """Logs per-stage and end-to-end latency for each conversational turn.

    Timestamps are recorded as specific frames pass through this processor:

    - ``UserStoppedSpeakingFrame`` starts the turn timer.
    - ``LLMFullResponseStartFrame`` logs the time elapsed since the user stopped.
    - ``TTSStartedFrame`` logs the time elapsed since the LLM response started.
    - ``BotStartedSpeakingFrame`` logs the time elapsed since TTS started and the
      total time since the user stopped, then clears all timers.

    Every frame is forwarded unchanged in the direction it arrived.
    """

    def __init__(self, **kwargs):
        """Initialize the tracker with no timestamps recorded.

        Args:
            **kwargs: Forwarded to ``FrameProcessor``.
        """
        super().__init__(**kwargs)
        # ``None`` means "not recorded for the current turn". A dedicated
        # sentinel avoids relying on the clock never returning 0.
        self._user_stopped_time = None
        self._llm_start_time = None
        self._tts_start_time = None

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Record timing for the frames of interest and pass every frame on.

        Args:
            frame: The frame travelling through the pipeline.
            direction: The direction the frame is travelling in.
        """
        await super().process_frame(frame, direction)

        # Use the monotonic clock: wall-clock time can jump (NTP, DST, manual
        # changes) and would produce negative or inflated latencies. Sample it
        # once so every figure logged for this frame refers to the same instant.
        now = time.monotonic()

        if isinstance(frame, UserStoppedSpeakingFrame):
            self._user_stopped_time = now
            logger.info("⏱️ User stopped speaking - timer started")

        elif isinstance(frame, LLMFullResponseStartFrame):
            self._llm_start_time = now
            if self._user_stopped_time is not None:
                stt_latency = (now - self._user_stopped_time) * 1000
                logger.info(f"⏱️ STT latency: {stt_latency:.0f}ms")

        elif isinstance(frame, TTSStartedFrame):
            self._tts_start_time = now
            if self._llm_start_time is not None:
                llm_latency = (now - self._llm_start_time) * 1000
                logger.info(f"⏱️ LLM TTFB: {llm_latency:.0f}ms")

        elif isinstance(frame, BotStartedSpeakingFrame):
            if self._user_stopped_time is not None:
                total_latency = (now - self._user_stopped_time) * 1000
                # Report 0 when no TTSStartedFrame was seen during this turn.
                tts_latency = (
                    (now - self._tts_start_time) * 1000
                    if self._tts_start_time is not None
                    else 0
                )
                logger.info(f"⏱️ TTS TTFB: {tts_latency:.0f}ms")
                logger.info(f"⏱️ ✨ TOTAL END-TO-END LATENCY: {total_latency:.0f}ms")

            # Reset for next turn so stale timestamps never leak into it.
            self._user_stopped_time = None
            self._llm_start_time = None
            self._tts_start_time = None

        await self.push_frame(frame, direction)
|
||||
|
||||
# Load API keys from a local .env file, overriding any values already set in
# the process environment.
load_dotenv(override=True)

# Replace loguru's default sink with a DEBUG-level stderr sink. Handler 0 is
# the default sink; it may already have been removed by another module, in
# which case loguru raises ValueError and there is nothing left to remove.
try:
    logger.remove(0)
except ValueError:
    pass
logger.add(sys.stderr, level="DEBUG")

# Default voice
DEFAULT_VOICE_ID = 147320
|
||||
|
||||
|
||||
async def main(voice_id: int):
    """Run a local-audio voice bot that speaks through Camb.ai TTS.

    Wires microphone input, Deepgram STT, an OpenAI LLM, Camb.ai TTS and speaker
    output into one pipeline, queues an introduction request once the pipeline
    has started, and then runs the pipeline until the runner returns.

    Args:
        voice_id: Camb.ai voice ID passed to ``CambTTSService``.
    """
    # mars-flash uses 22.05kHz
    output_rate = 22050

    # Local audio transport - uses your microphone and speakers.
    # audio_out_10ms_chunks is raised from the default (4 = 40ms) to 10, i.e. a
    # 100ms buffer for smoother playback.
    vad = SileroVADAnalyzer(params=VADParams(stop_secs=0.2))
    transport_params = LocalAudioTransportParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        audio_out_10ms_chunks=10,
        vad_analyzer=vad,
    )
    transport = LocalAudioTransport(transport_params)

    # Speech recognition, speech synthesis and language model services.
    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
    tts = CambTTSService(
        api_key=os.getenv("CAMB_API_KEY"),
        voice_id=voice_id,
        model="mars-flash",
    )
    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

    # System prompt
    system_prompt = (
        "You are a helpful voice assistant powered by Camb.ai\n"
        "text-to-speech technology. Keep your responses concise and conversational since\n"
        "they will be spoken aloud. Avoid special characters, emojis, or bullet points."
    )
    messages = [{"role": "system", "content": system_prompt}]

    # Context management
    context = LLMContext(messages)
    context_aggregator = LLMContextAggregatorPair(context)

    # Latency tracker for end-to-end timing
    latency_tracker = LatencyTracker()

    # Processors in the order frames flow from microphone to speaker.
    # NOTE(review): the tracker sits before the LLM and TTS stages; confirm it
    # actually receives LLMFullResponseStartFrame / TTSStartedFrame from here.
    stages = [
        transport.input(),  # Microphone input
        stt,  # Speech-to-text
        latency_tracker,  # Track latency at various stages
        context_aggregator.user(),  # User context
        llm,  # Language model
        tts,  # TTS
        transport.output(),  # Speaker output
        context_aggregator.assistant(),  # Assistant context
    ]
    pipeline = Pipeline(stages)

    # Pipeline task with metrics enabled and TTFB metrics logged by an observer.
    task_params = PipelineParams(
        audio_out_sample_rate=output_rate,
        enable_metrics=True,
        enable_usage_metrics=True,
    )
    ttfb_observer = MetricsLogObserver(include_metrics={TTFBMetricsData})
    task = PipelineTask(pipeline, params=task_params, observers=[ttfb_observer])

    # Start the conversation when the pipeline is ready
    @task.event_handler("on_pipeline_started")
    async def on_pipeline_started(task, frame):
        intro_request = {
            "role": "system",
            "content": "Please introduce yourself briefly and ask how you can help.",
        }
        messages.append(intro_request)
        await task.queue_frames([LLMRunFrame()])

    # Run the pipeline
    runner = PipelineRunner()
    logger.info("Starting Camb.ai TTS bot with local audio...")
    logger.info("Speak into your microphone to interact with the bot.")
    await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: read the optional voice ID, then run the bot.
    cli = argparse.ArgumentParser(description="Camb.ai TTS with local audio")
    cli.add_argument(
        "--voice-id",
        default=DEFAULT_VOICE_ID,
        type=int,
        help=f"Camb.ai voice ID (default: {DEFAULT_VOICE_ID})",
    )
    selected_voice = cli.parse_args().voice_id
    asyncio.run(main(selected_voice))
|
||||
@@ -9,6 +9,7 @@ import os
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
@@ -16,7 +17,10 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.camb.tts import CambTTSService
|
||||
@@ -25,6 +29,8 @@ from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.bot import TurnAnalyzerBotTurnStartStrategy
|
||||
from pipecat.turns.turn_start_strategies import TurnStartStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -52,7 +58,7 @@ transport_params = {
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting Camb.ai TTS bot")
|
||||
logger.info("Starting Camb AI TTS bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
@@ -66,14 +72,21 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful voice assistant powered by Camb.ai text-to-speech. "
|
||||
"content": "You are a helpful voice assistant powered by Camb AI text-to-speech. "
|
||||
"Keep your responses concise and conversational since they will be spoken aloud. "
|
||||
"Avoid special characters, emojis, or bullet points.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
turn_start_strategies=TurnStartStrategies(
|
||||
bot=[TurnAnalyzerBotTurnStartStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
@@ -92,7 +105,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
audio_out_sample_rate=22050,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
@@ -53,7 +53,7 @@ aws = [ "aioboto3~=15.5.0", "pipecat-ai[websockets-base]" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.2.0; python_version>='3.12'" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.44.0"]
|
||||
cartesia = [ "cartesia~=2.0.3", "pipecat-ai[websockets-base]" ]
|
||||
camb = [ "pipecat-ai[websockets-base]" ]
|
||||
camb = [ "camb-sdk>=1.5.4" ]
|
||||
cerebras = []
|
||||
daily = [ "daily-python~=0.23.0" ]
|
||||
deepgram = [ "deepgram-sdk~=4.7.0", "pipecat-ai[websockets-base]" ]
|
||||
|
||||
@@ -3,6 +3,3 @@
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
from .tts import *
|
||||
|
||||
@@ -16,10 +16,10 @@ Features:
|
||||
- Model-specific sample rates: mars-pro (48kHz), mars-flash (22.05kHz)
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, Mapping, Optional
|
||||
|
||||
from camb.client import AsyncCambAI
|
||||
from camb import StreamTtsOutputConfiguration
|
||||
from camb.client import AsyncCambAI
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -35,25 +35,13 @@ from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_VOICE_ID = 147320
|
||||
DEFAULT_LANGUAGE = "en-us"
|
||||
DEFAULT_MODEL = "mars-flash" # Faster inference
|
||||
DEFAULT_TIMEOUT = 60.0 # Seconds (minimum recommended by Camb.ai)
|
||||
MIN_TEXT_LENGTH = 3
|
||||
MAX_TEXT_LENGTH = 3000
|
||||
|
||||
# Model-specific sample rates
|
||||
MODEL_SAMPLE_RATES: Dict[str, int] = {
|
||||
"mars-flash": 22050, # 22.05kHz
|
||||
"mars-pro": 48000, # 48kHz
|
||||
"mars-flash": 22050, # 22.05kHz
|
||||
"mars-pro": 48000, # 48kHz
|
||||
"mars-instruct": 22050, # 22.05kHz
|
||||
}
|
||||
|
||||
# Gender mapping for voice listing
|
||||
GENDER_MAP = {0: "Not Specified", 1: "Male", 2: "Female", 9: "Not Applicable"}
|
||||
|
||||
|
||||
def language_to_camb_language(language: Language) -> Optional[str]:
|
||||
"""Convert a Pipecat Language enum to Camb.ai language code.
|
||||
@@ -132,6 +120,19 @@ def language_to_camb_language(language: Language) -> Optional[str]:
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
|
||||
|
||||
|
||||
def _get_aligned_audio(buffer: bytes) -> tuple[bytes, bytes]:
|
||||
"""Split buffer into aligned audio (2-byte samples) and remainder.
|
||||
|
||||
Args:
|
||||
buffer: Raw audio bytes to align.
|
||||
|
||||
Returns:
|
||||
Tuple of (aligned audio bytes, remaining bytes).
|
||||
"""
|
||||
aligned_size = (len(buffer) // 2) * 2
|
||||
return buffer[:aligned_size], buffer[aligned_size:]
|
||||
|
||||
|
||||
class CambTTSService(TTSService):
|
||||
"""Camb.ai MARS text-to-speech service using the official SDK.
|
||||
|
||||
@@ -176,9 +177,9 @@ class CambTTSService(TTSService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: int = DEFAULT_VOICE_ID,
|
||||
model: str = DEFAULT_MODEL,
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
voice_id: int = 147320,
|
||||
model: str = "mars-flash",
|
||||
timeout: float = 60.0,
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
@@ -187,10 +188,11 @@ class CambTTSService(TTSService):
|
||||
|
||||
Args:
|
||||
api_key: Camb.ai API key for authentication.
|
||||
voice_id: Voice ID to use. Defaults to DEFAULT_VOICE_ID.
|
||||
voice_id: Voice ID to use. Defaults to 147320.
|
||||
model: TTS model to use. Options: "mars-flash" (fast), "mars-pro" (high quality).
|
||||
Defaults to DEFAULT_MODEL (mars-flash).
|
||||
timeout: Request timeout in seconds. Defaults to DEFAULT_TIMEOUT (60s).
|
||||
Defaults to "mars-flash".
|
||||
timeout: Request timeout in seconds. Defaults to 60.0 (minimum recommended
|
||||
by Camb.ai).
|
||||
sample_rate: Audio sample rate in Hz. If None, uses model-specific default.
|
||||
params: Additional voice parameters. If None, uses defaults.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
@@ -201,19 +203,24 @@ class CambTTSService(TTSService):
|
||||
|
||||
self._client = AsyncCambAI(api_key=api_key, timeout=timeout)
|
||||
|
||||
# Warn if sample rate doesn't match model's supported rate
|
||||
if sample_rate and sample_rate != MODEL_SAMPLE_RATES.get(model):
|
||||
logger.warning(
|
||||
f"Camb.ai's {model} model only supports {MODEL_SAMPLE_RATES.get(model)}Hz "
|
||||
f"sample rate. Current rate of {sample_rate}Hz may cause issues."
|
||||
)
|
||||
|
||||
# Build settings
|
||||
self._settings = {
|
||||
"language": (
|
||||
self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else DEFAULT_LANGUAGE
|
||||
self.language_to_service_language(params.language) if params.language else "en-us"
|
||||
),
|
||||
"user_instructions": params.user_instructions,
|
||||
}
|
||||
|
||||
self.set_model_name(model)
|
||||
self.set_voice(str(voice_id))
|
||||
self._voice_id_int = voice_id
|
||||
self._voice_id = voice_id
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -262,7 +269,7 @@ class CambTTSService(TTSService):
|
||||
self._settings[key] = value
|
||||
logger.debug(f"Updated Camb.ai TTS setting {key} to: {value}")
|
||||
elif key == "voice_id":
|
||||
self._voice_id_int = int(value)
|
||||
self._voice_id = int(value)
|
||||
self.set_voice(str(value))
|
||||
|
||||
@traced_tts
|
||||
@@ -270,7 +277,7 @@ class CambTTSService(TTSService):
|
||||
"""Generate speech from text using Camb.ai's TTS API.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech (3-3000 characters).
|
||||
text: The text to synthesize into speech (max 3000 characters).
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech.
|
||||
@@ -278,16 +285,9 @@ class CambTTSService(TTSService):
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
# Validate text length
|
||||
if len(text) < MIN_TEXT_LENGTH:
|
||||
logger.warning(f"Text too short for Camb.ai TTS (min {MIN_TEXT_LENGTH} chars): {text}")
|
||||
yield TTSStoppedFrame()
|
||||
return
|
||||
|
||||
if len(text) > MAX_TEXT_LENGTH:
|
||||
logger.warning(
|
||||
f"Text too long for Camb.ai TTS (max {MAX_TEXT_LENGTH} chars), truncating"
|
||||
)
|
||||
text = text[:MAX_TEXT_LENGTH]
|
||||
if len(text) > 3000:
|
||||
logger.warning("Text too long for Camb.ai TTS (max 3000 chars), truncating")
|
||||
text = text[:3000]
|
||||
|
||||
try:
|
||||
await self.start_ttfb_metrics()
|
||||
@@ -295,9 +295,9 @@ class CambTTSService(TTSService):
|
||||
# Build SDK parameters
|
||||
tts_kwargs: Dict[str, Any] = {
|
||||
"text": text,
|
||||
"voice_id": self._voice_id_int,
|
||||
"voice_id": self._voice_id,
|
||||
"language": self._settings["language"],
|
||||
"speech_model": self._model_name,
|
||||
"speech_model": self.model_name,
|
||||
"output_configuration": StreamTtsOutputConfiguration(format="pcm_s16le"),
|
||||
}
|
||||
|
||||
@@ -318,71 +318,25 @@ class CambTTSService(TTSService):
|
||||
audio_buffer += chunk
|
||||
|
||||
# Only yield complete 16-bit samples (2 bytes per sample)
|
||||
aligned_size = (len(audio_buffer) // 2) * 2
|
||||
if aligned_size > 0:
|
||||
aligned_audio, audio_buffer = _get_aligned_audio(audio_buffer)
|
||||
if aligned_audio:
|
||||
yield TTSAudioRawFrame(
|
||||
audio=audio_buffer[:aligned_size],
|
||||
audio=aligned_audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
audio_buffer = audio_buffer[aligned_size:]
|
||||
|
||||
# Yield any remaining complete samples
|
||||
if len(audio_buffer) >= 2:
|
||||
aligned_size = (len(audio_buffer) // 2) * 2
|
||||
yield TTSAudioRawFrame(
|
||||
audio=audio_buffer[:aligned_size],
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
aligned_audio, _ = _get_aligned_audio(audio_buffer)
|
||||
if aligned_audio:
|
||||
yield TTSAudioRawFrame(
|
||||
audio=aligned_audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Camb.ai TTS error: {e}"
|
||||
logger.error(f"{self}: {error_msg}")
|
||||
yield ErrorFrame(error=error_msg)
|
||||
yield ErrorFrame(error=f"Camb.ai TTS error: {e}")
|
||||
finally:
|
||||
logger.debug(f"{self}: Finished TTS [{text}]")
|
||||
await self.stop_ttfb_metrics()
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@staticmethod
async def list_voices(api_key: str) -> List[Dict[str, Any]]:
    """Return the voices available to the given Camb.ai account.

    Entries without an ``id`` are dropped. The numeric gender code of each
    entry is translated through ``GENDER_MAP``; a missing code yields ``None``.

    Args:
        api_key: Camb.ai API key for authentication.

    Returns:
        List of voice dictionaries with ``id``, ``name``, ``gender``, ``age``
        and ``language`` keys.

    Raises:
        Exception: If the API request fails.

    Example::

        voices = await CambTTSService.list_voices(api_key="your-api-key")
        for voice in voices:
            print(f"{voice['id']}: {voice['name']}")
    """

    def _describe(entry):
        # Build the public dictionary for one raw voice entry.
        gender_code = entry.get("gender")
        return {
            "id": entry.get("id"),
            "name": entry.get("voice_name", ""),
            "gender": None if gender_code is None else GENDER_MAP.get(gender_code),
            "age": entry.get("age"),
            "language": entry.get("language"),
        }

    client = AsyncCambAI(api_key=api_key)
    raw_voices = await client.voice_cloning.list_voices()

    # Skip voices without an ID
    return [_describe(entry) for entry in raw_voices if entry.get("id") is not None]
|
||||
|
||||
@@ -1,384 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for CambTTSService.
|
||||
|
||||
These tests mock the Camb.ai SDK client to test the service behavior.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AggregatedTextFrame,
|
||||
ErrorFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSSpeakFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.services.camb.tts import (
|
||||
CambTTSService,
|
||||
DEFAULT_VOICE_ID,
|
||||
MODEL_SAMPLE_RATES,
|
||||
language_to_camb_language,
|
||||
)
|
||||
from pipecat.tests.utils import run_test
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
|
||||
async def mock_tts_stream(*args, **kwargs):
    """Stand in for the SDK's streaming TTS call by emitting one audio chunk.

    All positional and keyword arguments are accepted and ignored.
    """
    # 9600 bytes of placeholder PCM data (the byte pair 00 01 repeated).
    pcm_chunk = bytes((0x00, 0x01)) * 4800
    yield pcm_chunk
|
||||
|
||||
|
||||
async def mock_tts_stream_error(*args, **kwargs):
    """Stand in for the SDK's streaming TTS call, failing on first iteration.

    All positional and keyword arguments are accepted and ignored.
    """
    if False:
        # Never executed; the ``yield`` only makes this an async generator.
        yield
    raise Exception("API error: Invalid API key")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_camb_tts_success():
    """Check that a successful ``run_tts`` call produces the expected frames.

    The Camb.ai client is patched so its TTS call streams one PCM chunk. The
    test then verifies that started, audio and stopped frames are all present
    and that every audio frame is mono at the mars-flash sample rate.
    """
    expected_rate = MODEL_SAMPLE_RATES["mars-flash"]

    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        fake_client = MagicMock()
        fake_client.text_to_speech.tts = mock_tts_stream
        MockAsyncCambAI.return_value = fake_client

        tts_service = CambTTSService(api_key="test-api-key")

        # Manually set sample rate (normally done by StartFrame);
        # mars-flash uses 22.05kHz.
        tts_service._sample_rate = expected_rate

        # Call run_tts directly to avoid frame count variability.
        frames = [frame async for frame in tts_service.run_tts("Hello world, this is a test.")]

        # Every expected frame type must appear at least once.
        seen_types = {type(frame).__name__ for frame in frames}
        for required in ("TTSStartedFrame", "TTSAudioRawFrame", "TTSStoppedFrame"):
            assert required in seen_types, f"Should have {required}"

        audio_frames = [frame for frame in frames if isinstance(frame, TTSAudioRawFrame)]
        assert len(audio_frames) > 0, "Should have at least one audio frame"

        # Sample rate must match the model output (mars-flash = 22.05kHz).
        for audio in audio_frames:
            assert audio.sample_rate == expected_rate
            assert audio.num_channels == 1, "Should be mono audio"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_camb_tts_error():
    """A failing TTS stream should surface as an upstream ErrorFrame."""
    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        failing_client = MagicMock()
        failing_client.text_to_speech.tts = mock_tts_stream_error
        MockAsyncCambAI.return_value = failing_client

        service = CambTTSService(api_key="invalid-key")

        # TTSStartedFrame is emitted before the stream is iterated, so it is
        # still expected downstream even though the stream itself fails.
        frames_received = await run_test(
            service,
            frames_to_send=[TTSSpeakFrame(text="This should fail.")],
            expected_down_frames=[
                AggregatedTextFrame,
                TTSStartedFrame,
                TTSStoppedFrame,
                TTSTextFrame,
            ],
            expected_up_frames=[ErrorFrame],
        )

        # Index 1 of the run_test result holds the upstream frames.
        first_up = frames_received[1][0]
        assert isinstance(first_up, ErrorFrame), "Must receive an ErrorFrame"
        assert "error" in first_up.error.lower(), "ErrorFrame should contain error message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_voices():
    """Check ``CambTTSService.list_voices`` against a mocked dict-based API response."""

    async def mock_list_voices(*args, **kwargs):
        # Two voice records shaped like the dicts the API returns.
        attic = {"id": 2681, "voice_name": "Attic", "gender": 1, "age": 25, "language": None}
        cellar = {"id": 2682, "voice_name": "Cellar", "gender": 2, "age": 30, "language": "en-us"}
        return [attic, cellar]

    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        fake_client = MagicMock()
        fake_client.voice_cloning.list_voices = mock_list_voices
        MockAsyncCambAI.return_value = fake_client

        voices = await CambTTSService.list_voices(api_key="test-api-key")

        # Nothing should be filtered out of this response.
        assert len(voices) == 2, "Should return all voices"

        # Spot-check the normalised fields of one entry.
        attic_entry = next(entry for entry in voices if entry["id"] == 2681)
        assert attic_entry["name"] == "Attic"
        assert attic_entry["gender"] == "Male"
        assert attic_entry["age"] == 25
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_voices_skips_none_id():
    """Voices whose ``id`` is ``None`` or absent are left out of the listing."""

    async def mock_list_voices(*args, **kwargs):
        with_id = {"id": 123, "voice_name": "Valid", "gender": 1, "age": 25, "language": None}
        null_id = {"id": None, "voice_name": "NoID", "gender": 2, "age": 30, "language": None}
        no_id_key = {"voice_name": "MissingID", "gender": 1, "age": 35, "language": None}
        return [with_id, null_id, no_id_key]

    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        fake_client = MagicMock()
        fake_client.voice_cloning.list_voices = mock_list_voices
        MockAsyncCambAI.return_value = fake_client

        voices = await CambTTSService.list_voices(api_key="test-api-key")

        # Only the record carrying a real ID should survive.
        assert len(voices) == 1
        survivor = voices[0]
        assert survivor["id"] == 123
        assert survivor["name"] == "Valid"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_voices_handles_none_gender():
    """A voice with ``gender=None`` is still listed and keeps ``None`` as its gender."""

    async def mock_list_voices(*args, **kwargs):
        genderless = {
            "id": 123,
            "voice_name": "NoGender",
            "gender": None,
            "age": 25,
            "language": None,
        }
        return [genderless]

    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        fake_client = MagicMock()
        fake_client.voice_cloning.list_voices = mock_list_voices
        MockAsyncCambAI.return_value = fake_client

        voices = await CambTTSService.list_voices(api_key="test-api-key")

        assert len(voices) == 1
        assert voices[0]["gender"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_text_length_validation_too_short():
    """Test that text shorter than 3 characters is handled gracefully.

    Sends a 2-character TTSSpeakFrame and expects the downstream sequence
    AggregatedTextFrame -> TTSStoppedFrame -> TTSTextFrame with no audio
    frames, and no call to the TTS client at all.
    """
    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        mock_client = MagicMock()
        # TTS should not be called for short text. Keep a handle on the mock so
        # the call record can be checked explicitly after the run.
        tts_mock = AsyncMock(side_effect=AssertionError("TTS should not be called"))
        mock_client.text_to_speech.tts = tts_mock
        MockAsyncCambAI.return_value = mock_client

        tts_service = CambTTSService(api_key="test-api-key")

        frames_to_send = [
            TTSSpeakFrame(text="Hi"),  # Only 2 characters
        ]

        # For short text, we expect TTSStoppedFrame but no audio
        expected_down_frames = [AggregatedTextFrame, TTSStoppedFrame, TTSTextFrame]

        frames_received = await run_test(
            tts_service,
            frames_to_send=frames_to_send,
            expected_down_frames=expected_down_frames,
        )
        down_frames = frames_received[0]

        # Verify no audio frames were generated
        audio_frames = [f for f in down_frames if isinstance(f, TTSAudioRawFrame)]
        assert len(audio_frames) == 0, "Should not generate audio for text < 3 chars"

        # NOTE(review): the AssertionError side effect alone is not a reliable
        # guard - if the service catches exceptions raised by the client
        # (presumably it does, to emit ErrorFrames; confirm against the
        # service), the error would be swallowed. The mock records the call
        # regardless, so check that record directly.
        tts_mock.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_input_params():
    """Check ``InputParams`` defaults and explicit overrides."""
    # With no arguments: English, and no instructions.
    defaults = CambTTSService.InputParams()
    assert defaults.language == Language.EN
    assert defaults.user_instructions is None

    # Explicit values must be stored as given.
    overrides = {
        "language": Language.ES,
        "user_instructions": "Speak slowly and clearly",
    }
    custom = CambTTSService.InputParams(**overrides)
    assert custom.language == Language.ES
    assert custom.user_instructions == "Speak slowly and clearly"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_language_mapping():
    """Check the ``Language`` enum -> Camb.ai language code conversion for common languages."""
    expected_codes = [
        (Language.EN, "en-us"),
        (Language.EN_US, "en-us"),
        (Language.EN_GB, "en-gb"),
        (Language.ES, "es-es"),
        (Language.FR, "fr-fr"),
        (Language.DE, "de-de"),
        (Language.JA, "ja-jp"),
        (Language.ZH, "zh-cn"),
    ]
    for language, camb_code in expected_codes:
        assert language_to_camb_language(language) == camb_code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mars_instruct_model():
    """With the mars-instruct model, ``user_instructions`` must reach the TTS request."""
    captured = {}

    async def mock_tts_with_capture(*args, **kwargs):
        # Remember the keyword arguments the service passed to the client.
        captured["kwargs"] = kwargs
        yield b"\x00" * 1000

    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        fake_client = MagicMock()
        fake_client.text_to_speech.tts = mock_tts_with_capture
        MockAsyncCambAI.return_value = fake_client

        instruct_params = CambTTSService.InputParams(user_instructions="Speak with excitement")
        service = CambTTSService(
            api_key="test-api-key",
            model="mars-instruct",
            params=instruct_params,
        )

        expected_down = [
            AggregatedTextFrame,
            TTSStartedFrame,
            TTSAudioRawFrame,
            TTSStoppedFrame,
            TTSTextFrame,
        ]
        await run_test(
            service,
            frames_to_send=[TTSSpeakFrame(text="This is exciting news!")],
            expected_down_frames=expected_down,
        )

        # The request must carry both the model name and the instructions.
        # Falls back to an empty dict if the mock was never invoked.
        received_kwargs = captured.get("kwargs", {})
        assert received_kwargs.get("speech_model") == "mars-instruct"
        assert received_kwargs.get("user_instructions") == "Speak with excitement"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_client_initialization():
    """Constructing the service must build the client with the given key and timeout."""
    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        MockAsyncCambAI.return_value = MagicMock()

        CambTTSService(api_key="test-key", timeout=120.0)

        # Exactly one client, created with exactly these keyword arguments.
        MockAsyncCambAI.assert_called_once_with(api_key="test-key", timeout=120.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_default_voice_id_used():
    """When no voice is given, the service falls back to ``DEFAULT_VOICE_ID``."""
    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        MockAsyncCambAI.return_value = MagicMock()

        service = CambTTSService(api_key="test-key")

        assert service._voice_id_int == DEFAULT_VOICE_ID
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ttfb_metrics_tracked():
    """Test that TTFB metrics are properly tracked during TTS generation.

    Wraps the service's ``start_ttfb_metrics`` / ``stop_ttfb_metrics`` with
    recording shims, runs ``run_tts`` against a mock stream that sleeps before
    its first chunk, and checks that both hooks fired and that the time
    between them is at least the simulated delay.
    """
    import time

    # Artificial latency before the mock stream yields its first chunk.
    simulated_delay = 0.05

    ttfb_start_called = False
    ttfb_stop_called = False
    start_time = None
    stop_time = None

    async def mock_tts_with_delay(*args, **kwargs):
        # Simulate some network delay
        await asyncio.sleep(simulated_delay)
        yield b"\x00\x01" * 4800

    with patch("pipecat.services.camb.tts.AsyncCambAI") as MockAsyncCambAI:
        mock_client = MagicMock()
        mock_client.text_to_speech.tts = mock_tts_with_delay
        MockAsyncCambAI.return_value = mock_client

        tts_service = CambTTSService(api_key="test-api-key")
        tts_service._sample_rate = MODEL_SAMPLE_RATES["mars-flash"]

        # Patch the metrics methods to track calls
        original_start_ttfb = tts_service.start_ttfb_metrics
        original_stop_ttfb = tts_service.stop_ttfb_metrics

        async def patched_start_ttfb():
            nonlocal ttfb_start_called, start_time
            ttfb_start_called = True
            # Monotonic clock: a wall-clock adjustment mid-test must not be
            # able to shrink (or invert) the measured interval.
            start_time = time.monotonic()
            await original_start_ttfb()

        async def patched_stop_ttfb():
            nonlocal ttfb_stop_called, stop_time
            if not ttfb_stop_called:  # Only record first stop
                ttfb_stop_called = True
                stop_time = time.monotonic()
            await original_stop_ttfb()

        tts_service.start_ttfb_metrics = patched_start_ttfb
        tts_service.stop_ttfb_metrics = patched_stop_ttfb

        # Run TTS; the frames themselves are irrelevant here, only the
        # side effects on the metrics hooks matter.
        async for _ in tts_service.run_tts("Hello, this is a TTFB test."):
            pass

        # Verify TTFB tracking was called
        assert ttfb_start_called, "start_ttfb_metrics should be called"
        assert ttfb_stop_called, "stop_ttfb_metrics should be called"
        assert start_time is not None and stop_time is not None

        # TTFB should be >= simulated delay
        ttfb = stop_time - start_time
        assert ttfb >= simulated_delay, (
            f"TTFB ({ttfb:.3f}s) should be >= {simulated_delay}s (simulated delay)"
        )
|
||||
Reference in New Issue
Block a user