From eadd68d40b6233e4bffeb0bfa7171f2cdc0758c4 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Wed, 18 Sep 2024 14:19:04 -0400 Subject: [PATCH] Add sample_rate setting to TTS services --- src/pipecat/services/ai_services.py | 6 ++++++ src/pipecat/services/azure.py | 14 +++++++++--- src/pipecat/services/cartesia.py | 2 +- src/pipecat/services/elevenlabs.py | 4 ++-- src/pipecat/services/lmnt.py | 2 +- src/pipecat/services/openai.py | 33 ++++++++++++++++++++++------- src/pipecat/services/playht.py | 17 ++++++++++++--- 7 files changed, 60 insertions(+), 18 deletions(-) diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index dcba578c5..7291e7db9 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -171,6 +171,7 @@ class TTSService(AIService): push_stop_frames: bool = False, # if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame stop_frame_timeout_s: float = 1.0, + sample_rate: int = 16000, **kwargs): super().__init__(**kwargs) self._aggregate_sentences: bool = aggregate_sentences @@ -180,6 +181,11 @@ class TTSService(AIService): self._stop_frame_task: Optional[asyncio.Task] = None self._stop_frame_queue: asyncio.Queue = asyncio.Queue() self._current_sentence: str = "" + self._sample_rate: int = sample_rate + + @property + def sample_rate(self) -> int: + return self._sample_rate @abstractmethod async def set_model(self, model: str): diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py index 76e884992..c2f984b75 100644 --- a/src/pipecat/services/azure.py +++ b/src/pipecat/services/azure.py @@ -72,13 +72,21 @@ class AzureLLMService(BaseOpenAILLMService): class AzureTTSService(TTSService): - def __init__(self, *, api_key: str, region: str, voice="en-US-SaraNeural", **kwargs): - super().__init__(**kwargs) + def __init__( + self, + *, + api_key: str, + region: str, + voice="en-US-SaraNeural", + sample_rate: int = 16000, + **kwargs): + super().__init__(sample_rate=sample_rate, **kwargs) speech_config = SpeechConfig(subscription=api_key, region=region) self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None) self._voice = voice + self._sample_rate = sample_rate def can_generate_metrics(self) -> bool: return True @@ -109,7 +117,7 @@ class AzureTTSService(TTSService): await self.stop_ttfb_metrics() await self.push_frame(TTSStartedFrame()) # Azure always sends a 44-byte header. Strip it off. - yield AudioRawFrame(audio=result.audio_data[44:], sample_rate=16000, num_channels=1) + yield AudioRawFrame(audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1) await self.push_frame(TTSStoppedFrame()) elif result.reason == ResultReason.Canceled: cancellation_details = result.cancellation_details diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 7b4463812..ea790fab7 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -84,7 +84,7 @@ class CartesiaTTSService(AsyncWordTTSService): # if we're interrupted. Cartesia gives us word-by-word timestamps. We # can use those to generate text frames ourselves aligned with the # playout timing of the audio! - super().__init__(aggregate_sentences=True, push_text_frames=False, **kwargs) + super().__init__(aggregate_sentences=True, push_text_frames=False, sample_rate=sample_rate, **kwargs) self._api_key = api_key self._cartesia_version = cartesia_version diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py index a7a80033e..081a6bf5d 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs.py @@ -101,6 +101,7 @@ class ElevenLabsTTSService(AsyncWordTTSService): push_text_frames=False, push_stop_frames=True, stop_frame_timeout_s=2.0, + sample_rate=sample_rate_from_output_format(params.output_format), **kwargs ) @@ -109,7 +110,6 @@ class ElevenLabsTTSService(AsyncWordTTSService): self._model = model self._url = url self._params = params - self._sample_rate = sample_rate_from_output_format(params.output_format) # Websocket connection to ElevenLabs. self._websocket = None @@ -209,7 +209,7 @@ class ElevenLabsTTSService(AsyncWordTTSService): self.start_word_timestamps() audio = base64.b64decode(msg["audio"]) - frame = AudioRawFrame(audio, self._sample_rate, 1) + frame = AudioRawFrame(audio, self.sample_rate, 1) await self.push_frame(frame) if msg.get("alignment"): diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py index f5ad8aa1a..638e394a1 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt.py @@ -46,7 +46,7 @@ class LmntTTSService(AsyncTTSService): **kwargs): # Let TTSService produce TTSStoppedFrames after a short delay of # no activity. - super().__init__(push_stop_frames=True, **kwargs) + super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs) self._api_key = api_key self._voice_id = voice_id diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 2d0a24589..a03b350ba 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -11,7 +11,7 @@ import json import httpx from dataclasses import dataclass -from typing import AsyncGenerator, List, Literal +from typing import AsyncGenerator, Dict, List, Literal from loguru import logger from PIL import Image @@ -55,6 +55,17 @@ except ModuleNotFoundError as e: "In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable.") raise Exception(f"Missing module: {e}") +ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + +VALID_VOICES: Dict[str, ValidVoice] = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", +} + class OpenAIUnhandledFunctionException(Exception): pass @@ -182,8 +193,8 @@ class BaseOpenAILLMService(LLMService): if self.has_function(function_name): await self._handle_function_call(context, tool_call_id, function_name, arguments) else: - raise OpenAIUnhandledFunctionException( - f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function.") + raise OpenAIUnhandledFunctionException(f"The LLM tried to call a function named '{ + function_name}', but there isn't a callback registered for that function.") async def _handle_function_call( self, @@ -307,13 +318,15 @@ class OpenAITTSService(TTSService): self, *, api_key: str | None = None, - voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy", + voice: str = "alloy", model: Literal["tts-1", "tts-1-hd"] = "tts-1", + sample_rate: int = 24000, **kwargs): - super().__init__(**kwargs) + super().__init__(sample_rate=sample_rate, **kwargs) - self._voice = voice + self._voice: ValidVoice = VALID_VOICES.get(voice, "alloy") self._model = model + self._sample_rate = sample_rate self._client = AsyncOpenAI(api_key=api_key) @@ -322,7 +335,11 @@ class OpenAITTSService(TTSService): async def set_voice(self, voice: str): logger.debug(f"Switching TTS voice to: [{voice}]") - self._voice = voice + self._voice = VALID_VOICES.get(voice, self._voice) + + async def set_model(self, model: str): + logger.debug(f"Switching TTS model to: [{model}]") + self._model = model async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]") @@ -348,7 +365,7 @@ class OpenAITTSService(TTSService): async for chunk in r.iter_bytes(8192): if len(chunk) > 0: await self.stop_ttfb_metrics() - frame = AudioRawFrame(chunk, 24_000, 1) + frame = AudioRawFrame(chunk, self.sample_rate, 1) yield frame await self.push_frame(TTSStoppedFrame()) except BadRequestError as e: diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht.py index 2f4ae9851..c3200fee9 100644 --- a/src/pipecat/services/playht.py +++ b/src/pipecat/services/playht.py @@ -27,8 +27,15 @@ except ModuleNotFoundError as e: class PlayHTTTSService(TTSService): - def __init__(self, *, api_key: str, user_id: str, voice_url: str, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + *, + api_key: str, + user_id: str, + voice_url: str, + sample_rate: int = 16000, + **kwargs): + super().__init__(sample_rate=sample_rate, **kwargs) self._user_id = user_id self._speech_key = api_key @@ -39,13 +46,17 @@ class PlayHTTTSService(TTSService): ) self._options = TTSOptions( voice=voice_url, - sample_rate=16000, + sample_rate=sample_rate, quality="higher", format=Format.FORMAT_WAV) def can_generate_metrics(self) -> bool: return True + async def set_voice(self, voice: str): + logger.debug(f"Switching TTS voice to: [{voice}]") + self._options.voice = voice + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: logger.debug(f"Generating TTS: [{text}]")