feat(google): add location parameter to TTS services
This commit is contained in:
@@ -40,6 +40,7 @@ from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
|
||||
try:
|
||||
from google.api_core.client_options import ClientOptions
|
||||
from google.auth import default
|
||||
from google.auth.exceptions import GoogleAuthError
|
||||
from google.cloud import texttospeech_v1
|
||||
@@ -515,6 +516,7 @@ class GoogleHttpTTSService(TTSService):
|
||||
*,
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
voice_id: str = "en-US-Chirp3-HD-Charon",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
@@ -525,6 +527,7 @@ class GoogleHttpTTSService(TTSService):
|
||||
Args:
|
||||
credentials: JSON string containing Google Cloud service account credentials.
|
||||
credentials_path: Path to Google Cloud service account JSON file.
|
||||
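location: Google Cloud location (e.g., "us-central1") used to select the regional endpoint "<location>-texttospeech.googleapis.com". If None, the client's default endpoint is used.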
|
||||
voice_id: Google TTS voice identifier (e.g., "en-US-Standard-A").
|
||||
sample_rate: Audio sample rate in Hz. If None, uses default.
|
||||
params: Voice customization parameters including pitch, rate, volume, etc.
|
||||
@@ -534,6 +537,7 @@ class GoogleHttpTTSService(TTSService):
|
||||
|
||||
params = params or GoogleHttpTTSService.InputParams()
|
||||
|
||||
self._location = location
|
||||
self._settings = {
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
@@ -586,7 +590,15 @@ class GoogleHttpTTSService(TTSService):
|
||||
if not creds:
|
||||
raise ValueError("No valid credentials provided.")
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
|
||||
client_options = None
|
||||
if self._location:
|
||||
client_options = ClientOptions(
|
||||
api_endpoint=f"{self._location}-texttospeech.googleapis.com"
|
||||
)
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(
|
||||
credentials=creds, client_options=client_options
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -783,7 +795,15 @@ class GoogleBaseTTSService(TTSService):
|
||||
if not creds:
|
||||
raise ValueError("No valid credentials provided.")
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
|
||||
client_options = None
|
||||
if self._location:
|
||||
client_options = ClientOptions(
|
||||
api_endpoint=f"{self._location}-texttospeech.googleapis.com"
|
||||
)
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(
|
||||
credentials=creds, client_options=client_options
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -903,6 +923,7 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
*,
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
voice_id: str = "en-US-Chirp3-HD-Charon",
|
||||
voice_cloning_key: Optional[str] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
@@ -914,6 +935,7 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
Args:
|
||||
credentials: JSON string containing Google Cloud service account credentials.
|
||||
credentials_path: Path to Google Cloud service account JSON file.
|
||||
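location: Google Cloud location (e.g., "us-central1") used to select the regional endpoint "<location>-texttospeech.googleapis.com". If None, the client's default endpoint is used.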
|
||||
voice_id: Google TTS voice identifier (e.g., "en-US-Chirp3-HD-Charon").
|
||||
voice_cloning_key: The voice cloning key for Chirp 3 custom voices.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses default.
|
||||
@@ -924,6 +946,7 @@ class GoogleTTSService(GoogleBaseTTSService):
|
||||
|
||||
params = params or GoogleTTSService.InputParams()
|
||||
|
||||
self._location = location
|
||||
self._settings = {
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
@@ -1083,6 +1106,7 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
model: str = "gemini-2.5-flash-tts",
|
||||
credentials: Optional[str] = None,
|
||||
credentials_path: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
voice_id: str = "Kore",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
@@ -1101,6 +1125,7 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
"gemini-2.5-flash-tts" or "gemini-2.5-pro-tts".
|
||||
credentials: JSON string containing Google Cloud service account credentials.
|
||||
credentials_path: Path to Google Cloud service account JSON file.
|
||||
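location: Google Cloud location (e.g., "us-central1") used to select the regional endpoint "<location>-texttospeech.googleapis.com". If None, the client's default endpoint is used.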
|
||||
voice_id: Voice name from the available Gemini voices.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses Google's default 24kHz.
|
||||
params: TTS configuration parameters.
|
||||
@@ -1127,6 +1152,7 @@ class GeminiTTSService(GoogleBaseTTSService):
|
||||
if voice_id not in self.AVAILABLE_VOICES:
|
||||
logger.warning(f"Voice '{voice_id}' not in known voices list. Using anyway.")
|
||||
|
||||
self._location = location
|
||||
self._model = model
|
||||
self._voice_id = voice_id
|
||||
self._settings = {
|
||||
|
||||
Reference in New Issue
Block a user