Add sample_rate setting to TTS services
This commit is contained in:
@@ -171,6 +171,7 @@ class TTSService(AIService):
|
||||
push_stop_frames: bool = False,
|
||||
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
|
||||
stop_frame_timeout_s: float = 1.0,
|
||||
sample_rate: int = 16000,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._aggregate_sentences: bool = aggregate_sentences
|
||||
@@ -180,6 +181,11 @@ class TTSService(AIService):
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._current_sentence: str = ""
|
||||
self._sample_rate: int = sample_rate
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
return self._sample_rate
|
||||
|
||||
@abstractmethod
|
||||
async def set_model(self, model: str):
|
||||
|
||||
@@ -72,13 +72,21 @@ class AzureLLMService(BaseOpenAILLMService):
|
||||
|
||||
|
||||
class AzureTTSService(TTSService):
|
||||
def __init__(self, *, api_key: str, region: str, voice="en-US-SaraNeural", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
voice="en-US-SaraNeural",
|
||||
sample_rate: int = 16000,
|
||||
**kwargs):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
speech_config = SpeechConfig(subscription=api_key, region=region)
|
||||
self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)
|
||||
|
||||
self._voice = voice
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -109,7 +117,7 @@ class AzureTTSService(TTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
# Azure always sends a 44-byte header. Strip it off.
|
||||
yield AudioRawFrame(audio=result.audio_data[44:], sample_rate=16000, num_channels=1)
|
||||
yield AudioRawFrame(audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1)
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
elif result.reason == ResultReason.Canceled:
|
||||
cancellation_details = result.cancellation_details
|
||||
|
||||
@@ -84,7 +84,7 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
# if we're interrupted. Cartesia gives us word-by-word timestamps. We
|
||||
# can use those to generate text frames ourselves aligned with the
|
||||
# playout timing of the audio!
|
||||
super().__init__(aggregate_sentences=True, push_text_frames=False, **kwargs)
|
||||
super().__init__(aggregate_sentences=True, push_text_frames=False, sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
|
||||
@@ -101,6 +101,7 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
push_text_frames=False,
|
||||
push_stop_frames=True,
|
||||
stop_frame_timeout_s=2.0,
|
||||
sample_rate=sample_rate_from_output_format(params.output_format),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -109,7 +110,6 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
self._model = model
|
||||
self._url = url
|
||||
self._params = params
|
||||
self._sample_rate = sample_rate_from_output_format(params.output_format)
|
||||
|
||||
# Websocket connection to ElevenLabs.
|
||||
self._websocket = None
|
||||
@@ -209,7 +209,7 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
self.start_word_timestamps()
|
||||
|
||||
audio = base64.b64decode(msg["audio"])
|
||||
frame = AudioRawFrame(audio, self._sample_rate, 1)
|
||||
frame = AudioRawFrame(audio, self.sample_rate, 1)
|
||||
await self.push_frame(frame)
|
||||
|
||||
if msg.get("alignment"):
|
||||
|
||||
@@ -46,7 +46,7 @@ class LmntTTSService(AsyncTTSService):
|
||||
**kwargs):
|
||||
# Let TTSService produce TTSStoppedFrames after a short delay of
|
||||
# no activity.
|
||||
super().__init__(push_stop_frames=True, **kwargs)
|
||||
super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
|
||||
@@ -11,7 +11,7 @@ import json
|
||||
import httpx
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import AsyncGenerator, List, Literal
|
||||
from typing import AsyncGenerator, Dict, List, Literal
|
||||
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
@@ -55,6 +55,17 @@ except ModuleNotFoundError as e:
|
||||
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
|
||||
|
||||
VALID_VOICES: Dict[str, ValidVoice] = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
|
||||
|
||||
class OpenAIUnhandledFunctionException(Exception):
|
||||
pass
|
||||
@@ -182,8 +193,8 @@ class BaseOpenAILLMService(LLMService):
|
||||
if self.has_function(function_name):
|
||||
await self._handle_function_call(context, tool_call_id, function_name, arguments)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function.")
|
||||
raise OpenAIUnhandledFunctionException(f"The LLM tried to call a function named '{
|
||||
function_name}', but there isn't a callback registered for that function.")
|
||||
|
||||
async def _handle_function_call(
|
||||
self,
|
||||
@@ -307,13 +318,15 @@ class OpenAITTSService(TTSService):
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy",
|
||||
voice: str = "alloy",
|
||||
model: Literal["tts-1", "tts-1-hd"] = "tts-1",
|
||||
sample_rate: int = 24000,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._voice = voice
|
||||
self._voice: ValidVoice = VALID_VOICES.get(voice, "alloy")
|
||||
self._model = model
|
||||
self._sample_rate = sample_rate
|
||||
|
||||
self._client = AsyncOpenAI(api_key=api_key)
|
||||
|
||||
@@ -322,7 +335,11 @@ class OpenAITTSService(TTSService):
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = voice
|
||||
self._voice = VALID_VOICES.get(voice, self._voice)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
self._model = model
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
@@ -348,7 +365,7 @@ class OpenAITTSService(TTSService):
|
||||
async for chunk in r.iter_bytes(8192):
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = AudioRawFrame(chunk, 24_000, 1)
|
||||
frame = AudioRawFrame(chunk, self.sample_rate, 1)
|
||||
yield frame
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
except BadRequestError as e:
|
||||
|
||||
@@ -27,8 +27,15 @@ except ModuleNotFoundError as e:
|
||||
|
||||
class PlayHTTTSService(TTSService):
|
||||
|
||||
def __init__(self, *, api_key: str, user_id: str, voice_url: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
user_id: str,
|
||||
voice_url: str,
|
||||
sample_rate: int = 16000,
|
||||
**kwargs):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._user_id = user_id
|
||||
self._speech_key = api_key
|
||||
@@ -39,13 +46,17 @@ class PlayHTTTSService(TTSService):
|
||||
)
|
||||
self._options = TTSOptions(
|
||||
voice=voice_url,
|
||||
sample_rate=16000,
|
||||
sample_rate=sample_rate,
|
||||
quality="higher",
|
||||
format=Format.FORMAT_WAV)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._options.voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user