diff --git a/src/pipecat/services/aws/stt.py b/src/pipecat/services/aws/stt.py index ae502e8be..21220e646 100644 --- a/src/pipecat/services/aws/stt.py +++ b/src/pipecat/services/aws/stt.py @@ -102,7 +102,7 @@ class AWSTranscribeSTTService(WebsocketSTTService): super().__init__(ttfs_p99_latency=ttfs_p99_latency, **kwargs) self._settings = AWSTranscribeSTTSettings( - language=language, + language=self.language_to_service_language(language) or "en-US", sample_rate=sample_rate, media_encoding="linear16", number_of_channels=1, @@ -251,9 +251,9 @@ class AWSTranscribeSTTService(WebsocketSTTService): logger.debug("Connecting to AWS Transcribe WebSocket") - language_code = self.language_to_service_language(Language(self._settings.language)) + language_code = self._settings.language if not language_code: - raise ValueError(f"Unsupported language: {self._settings.language}") + raise ValueError(f"Unsupported language: {language_code}") # Generate random websocket key websocket_key = "".join( diff --git a/src/pipecat/services/azure/stt.py b/src/pipecat/services/azure/stt.py index 8a5b09e26..7f9d3f1ba 100644 --- a/src/pipecat/services/azure/stt.py +++ b/src/pipecat/services/azure/stt.py @@ -123,6 +123,17 @@ class AzureSTTService(STTService): """ return True + def language_to_service_language(self, language: Language) -> Optional[str]: + """Convert a Language enum to Azure service-specific language code. + + Args: + language: The language to convert. + + Returns: + The Azure-specific language identifier, or None if not supported. + """ + return language_to_azure_language(language) + async def _update_settings(self, update: STTSettings) -> dict[str, Any]: """Apply a settings update. @@ -130,13 +141,6 @@ class AzureSTTService(STTService): """ changed = await super()._update_settings(update) - if "language" in changed: - # Convert Language enum to Azure language code for consistency. - lang = self._settings.language - if isinstance(lang, Language): - lang = language_to_azure_language(lang) - self._settings.language = lang - # TODO: someday we could reconnect here to apply updated settings. # Code might look something like the below: # if "language" in changed: diff --git a/src/pipecat/services/elevenlabs/stt.py b/src/pipecat/services/elevenlabs/stt.py index fd938c12e..0ef137006 100644 --- a/src/pipecat/services/elevenlabs/stt.py +++ b/src/pipecat/services/elevenlabs/stt.py @@ -34,7 +34,7 @@ from pipecat.frames.frames import ( VADUserStoppedSpeakingFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.settings import NOT_GIVEN, STTSettings, is_given +from pipecat.services.settings import NOT_GIVEN, STTSettings from pipecat.services.stt_latency import ELEVENLABS_REALTIME_TTFS_P99, ELEVENLABS_TTFS_P99 from pipecat.services.stt_service import SegmentedSTTService, WebsocketSTTService from pipecat.transcriptions.language import Language, resolve_language @@ -306,12 +306,6 @@ class ElevenLabsSTTService(SegmentedSTTService): Returns: Dict mapping changed field names to their previous values. """ - # Convert language to ElevenLabs format before applying - if is_given(update.language) and isinstance(update.language, Language): - converted = self.language_to_service_language(update.language) - if converted is not None: - update.language = converted - changed = await super()._update_settings(update) if "model" in changed: @@ -555,12 +549,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService): Returns: Dict mapping changed field names to their previous values. """ - # Convert language to ElevenLabs format before applying - if is_given(update.language) and isinstance(update.language, Language): - converted = language_to_elevenlabs_language(update.language) - if converted is not None: - update.language = converted - changed = await super()._update_settings(update) if not changed: diff --git a/src/pipecat/services/fal/stt.py b/src/pipecat/services/fal/stt.py index 28b611865..a29d8d70d 100644 --- a/src/pipecat/services/fal/stt.py +++ b/src/pipecat/services/fal/stt.py @@ -252,15 +252,8 @@ class FalSTTService(SegmentedSTTService): return language_to_fal_language(language) async def _update_settings(self, update: STTSettings) -> dict[str, Any]: - """Apply a settings update, converting language if changed.""" + """Apply a settings update.""" changed = await super()._update_settings(update) - - if "language" in changed: - # Convert the Language enum to a Fal language code. - lang = self._settings.language - if isinstance(lang, Language): - self._settings.language = self.language_to_service_language(lang) - return changed @traced_stt diff --git a/src/pipecat/services/nvidia/stt.py b/src/pipecat/services/nvidia/stt.py index b0d11fc2b..8e1babec7 100644 --- a/src/pipecat/services/nvidia/stt.py +++ b/src/pipecat/services/nvidia/stt.py @@ -488,7 +488,8 @@ class NvidiaSegmentedSTTService(SegmentedSTTService): self._config = None self._asr_service = None self._settings = NvidiaSegmentedSTTSettings( - language=params.language or Language.EN_US, + language=self.language_to_service_language(params.language or Language.EN_US) + or "en-US", profanity_filter=params.profanity_filter, automatic_punctuation=params.automatic_punctuation, verbatim_transcripts=params.verbatim_transcripts, @@ -523,8 +524,8 @@ class NvidiaSegmentedSTTService(SegmentedSTTService): self._asr_service = riva.client.ASRService(auth) def _get_language_code(self) -> str: - """Resolve the current language enum to an NVIDIA Riva language code string.""" - return self.language_to_service_language(self._settings.language) or "en-US" + """Get the current NVIDIA Riva language code string.""" + return self._settings.language or "en-US" def _create_recognition_config(self): """Create the NVIDIA Riva ASR recognition configuration.""" diff --git a/src/pipecat/services/settings.py b/src/pipecat/services/settings.py index 54a25124b..fdc1c15e6 100644 --- a/src/pipecat/services/settings.py +++ b/src/pipecat/services/settings.py @@ -34,6 +34,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Mapping, Optional, Type, from loguru import logger +from pipecat.transcriptions.language import Language + if TYPE_CHECKING: from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionConfig @@ -294,11 +296,13 @@ class TTSSettings(ServiceSettings): Parameters: model: TTS model identifier. voice: Voice identifier or name. - language: Language for speech synthesis. + language: Language for speech synthesis. Accepts a ``Language`` enum + (converted to a service-specific string) or a raw string (stored + as-is). """ voice: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN) - language: Any = field(default_factory=lambda: NOT_GIVEN) + language: Language | str | _NotGiven = field(default_factory=lambda: NOT_GIVEN) _aliases: ClassVar[Dict[str, str]] = {"voice_id": "voice"} @@ -309,7 +313,9 @@ class STTSettings(ServiceSettings): Parameters: model: STT model identifier. - language: Language for speech recognition. + language: Language for speech recognition. Accepts a ``Language`` enum + (converted to a service-specific string) or a raw string (stored + as-is). """ - language: Any = field(default_factory=lambda: NOT_GIVEN) + language: Language | str | _NotGiven = field(default_factory=lambda: NOT_GIVEN) diff --git a/src/pipecat/services/stt_service.py b/src/pipecat/services/stt_service.py index d6dd31824..ae04ed33f 100644 --- a/src/pipecat/services/stt_service.py +++ b/src/pipecat/services/stt_service.py @@ -35,7 +35,7 @@ from pipecat.frames.frames import ( from pipecat.metrics.metrics import TTFBMetricsData from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_service import AIService -from pipecat.services.settings import STTSettings +from pipecat.services.settings import STTSettings, is_given from pipecat.services.stt_latency import DEFAULT_TTFS_P99 from pipecat.services.websocket_service import WebsocketService from pipecat.transcriptions.language import Language @@ -206,6 +206,17 @@ class STTService(AIService): settings_cls = type(self._settings) await self._update_settings(settings_cls(language=language)) + def language_to_service_language(self, language: Language) -> Optional[str]: + """Convert a language to the service-specific language format. + + Args: + language: The language to convert. + + Returns: + The service-specific language identifier, or None if not supported. + """ + return Language(language) + @abstractmethod async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: """Run speech-to-text on the provided audio data. @@ -239,8 +250,9 @@ class STTService(AIService): async def _update_settings(self, update: STTSettings) -> dict[str, Any]: """Apply an STT settings update. - Handles ``model`` (via parent). Does **not** call ``set_language`` - — concrete services should override this method and handle language + Handles ``model`` (via parent). Translates ``Language`` enum values + before applying so the stored value is a service-specific string. + Concrete services should override this method and handle language changes (including any reconnect logic) based on the returned changed-field dict. @@ -250,6 +262,12 @@ class STTService(AIService): Returns: Dict mapping changed field names to their previous values. """ + # Translate language *before* applying so the stored value is canonical + if is_given(update.language) and isinstance(update.language, Language): + converted = self.language_to_service_language(update.language) + if converted is not None: + update.language = converted + changed = await super()._update_settings(update) return changed diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index bb5cff69f..4b4b47a50 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -448,7 +448,7 @@ class TTSService(AIService): Dict mapping changed field names to their previous values. """ # Translate language *before* applying so the stored value is canonical - if is_given(update.language) and update.language is not None: + if is_given(update.language) and isinstance(update.language, Language): converted = self.language_to_service_language(update.language) if converted is not None: update.language = converted diff --git a/src/pipecat/services/whisper/base_stt.py b/src/pipecat/services/whisper/base_stt.py index a67ad1cbc..d50c24eb2 100644 --- a/src/pipecat/services/whisper/base_stt.py +++ b/src/pipecat/services/whisper/base_stt.py @@ -183,7 +183,7 @@ class BaseWhisperSTTService(SegmentedSTTService): changed = await super()._update_settings(update) if "language" in changed: - self._language = self.language_to_service_language(Language(self._settings.language)) + self._language = self._settings.language if "prompt" in changed: self._prompt = self._settings.prompt if "temperature" in changed: diff --git a/src/pipecat/services/whisper/stt.py b/src/pipecat/services/whisper/stt.py index a96c26992..d4efcb166 100644 --- a/src/pipecat/services/whisper/stt.py +++ b/src/pipecat/services/whisper/stt.py @@ -319,9 +319,8 @@ class WhisperSTTService(SegmentedSTTService): # Divide by 32768 because we have signed 16-bit data. audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0 - whisper_lang = self.language_to_service_language(self._settings.language) segments, _ = await asyncio.to_thread( - self._model.transcribe, audio_float, language=whisper_lang + self._model.transcribe, audio_float, language=self._settings.language ) text: str = "" for segment in segments: @@ -419,13 +418,12 @@ class WhisperSTTServiceMLX(WhisperSTTService): # Divide by 32768 because we have signed 16-bit data. audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0 - whisper_lang = self.language_to_service_language(self._settings.language) chunk = await asyncio.to_thread( mlx_whisper.transcribe, audio_float, path_or_hf_repo=self.model_name, temperature=self._temperature, - language=whisper_lang, + language=self._settings.language, ) text: str = "" for segment in chunk.get("segments", []):