Add sample_rate setting to TTS services

This commit is contained in:
Mark Backman
2024-09-18 14:19:04 -04:00
parent 13a4a05388
commit eadd68d40b
7 changed files with 60 additions and 18 deletions

View File

@@ -171,6 +171,7 @@ class TTSService(AIService):
push_stop_frames: bool = False,
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
stop_frame_timeout_s: float = 1.0,
sample_rate: int = 16000,
**kwargs):
super().__init__(**kwargs)
self._aggregate_sentences: bool = aggregate_sentences
@@ -180,6 +181,11 @@ class TTSService(AIService):
self._stop_frame_task: Optional[asyncio.Task] = None
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
self._current_sentence: str = ""
self._sample_rate: int = sample_rate
# Property exposing the value stored on the instance as `_sample_rate`
# (assigned from the constructor's `sample_rate` argument).
@property
def sample_rate(self) -> int:
return self._sample_rate
@abstractmethod
async def set_model(self, model: str):

View File

@@ -72,13 +72,21 @@ class AzureLLMService(BaseOpenAILLMService):
class AzureTTSService(TTSService):
def __init__(self, *, api_key: str, region: str, voice="en-US-SaraNeural", **kwargs):
super().__init__(**kwargs)
def __init__(
        self,
        *,
        api_key: str,
        region: str,
        voice: str = "en-US-SaraNeural",
        sample_rate: int = 16000,
        **kwargs):
    """Create an Azure TTS service.

    Args:
        api_key: Azure Speech subscription key, passed to SpeechConfig.
        region: Azure region, passed to SpeechConfig.
        voice: Voice name, stored on the instance as `_voice`.
        sample_rate: Sample rate in Hz attached to the audio frames this
            service yields. It is stored by the TTSService constructor.
        **kwargs: Forwarded unchanged to the TTSService constructor.
    """
    # The base constructor already stores `sample_rate` as
    # `self._sample_rate`, so it is not assigned a second time here.
    super().__init__(sample_rate=sample_rate, **kwargs)

    # NOTE(review): no synthesis output format is set on this SpeechConfig,
    # so the audio Azure returns may not actually be at `sample_rate` when a
    # non-default value is passed — confirm and configure the output format
    # to match if needed.
    speech_config = SpeechConfig(subscription=api_key, region=region)
    self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)

    self._voice = voice
# Always returns True.
# NOTE(review): presumably this opts the service into metrics reporting
# (the synthesis path calls stop_ttfb_metrics()) — confirm against the base class.
def can_generate_metrics(self) -> bool:
return True
@@ -109,7 +117,7 @@ class AzureTTSService(TTSService):
await self.stop_ttfb_metrics()
await self.push_frame(TTSStartedFrame())
# Azure always sends a 44-byte header. Strip it off.
yield AudioRawFrame(audio=result.audio_data[44:], sample_rate=16000, num_channels=1)
yield AudioRawFrame(audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1)
await self.push_frame(TTSStoppedFrame())
elif result.reason == ResultReason.Canceled:
cancellation_details = result.cancellation_details

View File

@@ -84,7 +84,7 @@ class CartesiaTTSService(AsyncWordTTSService):
# if we're interrupted. Cartesia gives us word-by-word timestamps. We
# can use those to generate text frames ourselves aligned with the
# playout timing of the audio!
super().__init__(aggregate_sentences=True, push_text_frames=False, **kwargs)
super().__init__(aggregate_sentences=True, push_text_frames=False, sample_rate=sample_rate, **kwargs)
self._api_key = api_key
self._cartesia_version = cartesia_version

View File

@@ -101,6 +101,7 @@ class ElevenLabsTTSService(AsyncWordTTSService):
push_text_frames=False,
push_stop_frames=True,
stop_frame_timeout_s=2.0,
sample_rate=sample_rate_from_output_format(params.output_format),
**kwargs
)
@@ -109,7 +110,6 @@ class ElevenLabsTTSService(AsyncWordTTSService):
self._model = model
self._url = url
self._params = params
self._sample_rate = sample_rate_from_output_format(params.output_format)
# Websocket connection to ElevenLabs.
self._websocket = None
@@ -209,7 +209,7 @@ class ElevenLabsTTSService(AsyncWordTTSService):
self.start_word_timestamps()
audio = base64.b64decode(msg["audio"])
frame = AudioRawFrame(audio, self._sample_rate, 1)
frame = AudioRawFrame(audio, self.sample_rate, 1)
await self.push_frame(frame)
if msg.get("alignment"):

View File

@@ -46,7 +46,7 @@ class LmntTTSService(AsyncTTSService):
**kwargs):
# Let TTSService produce TTSStoppedFrames after a short delay of
# no activity.
super().__init__(push_stop_frames=True, **kwargs)
super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs)
self._api_key = api_key
self._voice_id = voice_id

View File

@@ -11,7 +11,7 @@ import json
import httpx
from dataclasses import dataclass
from typing import AsyncGenerator, List, Literal
from typing import AsyncGenerator, Dict, List, Literal
from loguru import logger
from PIL import Image
@@ -55,6 +55,17 @@ except ModuleNotFoundError as e:
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable.")
raise Exception(f"Missing module: {e}")
# The closed set of voice names, expressed as a Literal type.
ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
# Lookup table from a plain `str` voice name to the same name typed as
# ValidVoice. OpenAITTSService uses `.get()` on it to validate a
# caller-supplied voice and fall back when the name is unknown.
VALID_VOICES: Dict[str, ValidVoice] = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
# Raised when the LLM asks to call a function for which no callback has
# been registered on the service.
class OpenAIUnhandledFunctionException(Exception):
pass
@@ -182,8 +193,8 @@ class BaseOpenAILLMService(LLMService):
if self.has_function(function_name):
await self._handle_function_call(context, tool_call_id, function_name, arguments)
else:
raise OpenAIUnhandledFunctionException(
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function.")
# Keep the f-string on a single line: a line break inside a replacement
# field of a single-quoted f-string is a SyntaxError before Python 3.12.
raise OpenAIUnhandledFunctionException(
    f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function.")
async def _handle_function_call(
self,
@@ -307,13 +318,15 @@ class OpenAITTSService(TTSService):
self,
*,
api_key: str | None = None,
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy",
voice: str = "alloy",
model: Literal["tts-1", "tts-1-hd"] = "tts-1",
sample_rate: int = 24000,
**kwargs):
super().__init__(**kwargs)
super().__init__(sample_rate=sample_rate, **kwargs)
self._voice = voice
self._voice: ValidVoice = VALID_VOICES.get(voice, "alloy")
self._model = model
self._sample_rate = sample_rate
self._client = AsyncOpenAI(api_key=api_key)
@@ -322,7 +335,11 @@ class OpenAITTSService(TTSService):
async def set_voice(self, voice: str):
    """Switch the voice used for subsequent TTS requests.

    Only names present in VALID_VOICES are accepted. Any other value is
    rejected with a warning and the current voice is kept.

    Args:
        voice: Requested voice name.
    """
    validated = VALID_VOICES.get(voice)
    if validated is None:
        # Do not announce a switch that is not going to happen, and never
        # store an unvalidated name.
        logger.warning(f"Unsupported TTS voice [{voice}], keeping [{self._voice}]")
        return
    logger.debug(f"Switching TTS voice to: [{voice}]")
    self._voice = validated
# Replace the model name stored on the service. The value is logged and
# stored as given; no validation is performed here.
async def set_model(self, model: str):
logger.debug(f"Switching TTS model to: [{model}]")
self._model = model
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")
@@ -348,7 +365,7 @@ class OpenAITTSService(TTSService):
async for chunk in r.iter_bytes(8192):
if len(chunk) > 0:
await self.stop_ttfb_metrics()
frame = AudioRawFrame(chunk, 24_000, 1)
frame = AudioRawFrame(chunk, self.sample_rate, 1)
yield frame
await self.push_frame(TTSStoppedFrame())
except BadRequestError as e:

View File

@@ -27,8 +27,15 @@ except ModuleNotFoundError as e:
class PlayHTTTSService(TTSService):
def __init__(self, *, api_key: str, user_id: str, voice_url: str, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
*,
api_key: str,
user_id: str,
voice_url: str,
sample_rate: int = 16000,
**kwargs):
super().__init__(sample_rate=sample_rate, **kwargs)
self._user_id = user_id
self._speech_key = api_key
@@ -39,13 +46,17 @@ class PlayHTTTSService(TTSService):
)
self._options = TTSOptions(
voice=voice_url,
sample_rate=16000,
sample_rate=sample_rate,
quality="higher",
format=Format.FORMAT_WAV)
# Always returns True.
# NOTE(review): presumably this opts the service into metrics reporting —
# confirm against the base class.
def can_generate_metrics(self) -> bool:
return True
# Update the voice on the stored TTSOptions (`self._options`); the value is
# logged and assigned as given, with no validation here.
# NOTE(review): the constructor fills this field from `voice_url`, so a
# PlayHT voice URL is presumably expected — confirm.
async def set_voice(self, voice: str):
logger.debug(f"Switching TTS voice to: [{voice}]")
self._options.voice = voice
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")