Type language fields and centralize conversion in STT services.

Change `TTSSettings.language` and `STTSettings.language` from `Any` to `Language | str | _NotGiven`. Add a `language_to_service_language` base method to `STTService`, plus a centralized `isinstance`-guarded conversion in `STTService._update_settings` (mirroring TTS). Update the TTS guard from `is not None` to `isinstance(…, Language)` so raw strings pass through unchanged.

Remove now-redundant per-service language conversion from `_update_settings` overrides (ElevenLabs, Azure, Fal, Whisper). Add `language_to_service_language` to Azure STT so the centralized conversion picks it up. Fix AWS and NVIDIA STT `__init__` to convert language at construction time, then simplify their runtime accessors to read `_settings.language` directly.
This commit is contained in:
Paul Kompfner
2026-02-17 14:49:26 -05:00
parent d2372c127a
commit 7dc16b1d92
10 changed files with 55 additions and 47 deletions

View File

@@ -102,7 +102,7 @@ class AWSTranscribeSTTService(WebsocketSTTService):
super().__init__(ttfs_p99_latency=ttfs_p99_latency, **kwargs)
self._settings = AWSTranscribeSTTSettings(
language=language,
language=self.language_to_service_language(language) or "en-US",
sample_rate=sample_rate,
media_encoding="linear16",
number_of_channels=1,
@@ -251,9 +251,9 @@ class AWSTranscribeSTTService(WebsocketSTTService):
logger.debug("Connecting to AWS Transcribe WebSocket")
language_code = self.language_to_service_language(Language(self._settings.language))
language_code = self._settings.language
if not language_code:
raise ValueError(f"Unsupported language: {self._settings.language}")
raise ValueError(f"Unsupported language: {language_code}")
# Generate random websocket key
websocket_key = "".join(

View File

@@ -123,6 +123,17 @@ class AzureSTTService(STTService):
"""
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
    """Convert a Language enum to Azure service-specific language code.

    Thin wrapper that delegates to ``language_to_azure_language``. The
    inherited ``STTService._update_settings`` calls this hook when a
    settings update carries a ``Language`` enum, and leaves the update's
    value untouched if the result is ``None``.

    Args:
        language: The language to convert.

    Returns:
        The Azure-specific language identifier, or None if not supported.
    """
    return language_to_azure_language(language)
async def _update_settings(self, update: STTSettings) -> dict[str, Any]:
"""Apply a settings update.
@@ -130,13 +141,6 @@ class AzureSTTService(STTService):
"""
changed = await super()._update_settings(update)
if "language" in changed:
# Convert Language enum to Azure language code for consistency.
lang = self._settings.language
if isinstance(lang, Language):
lang = language_to_azure_language(lang)
self._settings.language = lang
# TODO: someday we could reconnect here to apply updated settings.
# Code might look something like the below:
# if "language" in changed:

View File

@@ -34,7 +34,7 @@ from pipecat.frames.frames import (
VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, STTSettings, is_given
from pipecat.services.settings import NOT_GIVEN, STTSettings
from pipecat.services.stt_latency import ELEVENLABS_REALTIME_TTFS_P99, ELEVENLABS_TTFS_P99
from pipecat.services.stt_service import SegmentedSTTService, WebsocketSTTService
from pipecat.transcriptions.language import Language, resolve_language
@@ -306,12 +306,6 @@ class ElevenLabsSTTService(SegmentedSTTService):
Returns:
Dict mapping changed field names to their previous values.
"""
# Convert language to ElevenLabs format before applying
if is_given(update.language) and isinstance(update.language, Language):
converted = self.language_to_service_language(update.language)
if converted is not None:
update.language = converted
changed = await super()._update_settings(update)
if "model" in changed:
@@ -555,12 +549,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
Returns:
Dict mapping changed field names to their previous values.
"""
# Convert language to ElevenLabs format before applying
if is_given(update.language) and isinstance(update.language, Language):
converted = language_to_elevenlabs_language(update.language)
if converted is not None:
update.language = converted
changed = await super()._update_settings(update)
if not changed:

View File

@@ -252,15 +252,8 @@ class FalSTTService(SegmentedSTTService):
return language_to_fal_language(language)
async def _update_settings(self, update: STTSettings) -> dict[str, Any]:
"""Apply a settings update, converting language if changed."""
"""Apply a settings update."""
changed = await super()._update_settings(update)
if "language" in changed:
# Convert the Language enum to a Fal language code.
lang = self._settings.language
if isinstance(lang, Language):
self._settings.language = self.language_to_service_language(lang)
return changed
@traced_stt

View File

@@ -488,7 +488,8 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
self._config = None
self._asr_service = None
self._settings = NvidiaSegmentedSTTSettings(
language=params.language or Language.EN_US,
language=self.language_to_service_language(params.language or Language.EN_US)
or "en-US",
profanity_filter=params.profanity_filter,
automatic_punctuation=params.automatic_punctuation,
verbatim_transcripts=params.verbatim_transcripts,
@@ -523,8 +524,8 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
self._asr_service = riva.client.ASRService(auth)
def _get_language_code(self) -> str:
"""Resolve the current language enum to an NVIDIA Riva language code string."""
return self.language_to_service_language(self._settings.language) or "en-US"
"""Get the current NVIDIA Riva language code string."""
return self._settings.language or "en-US"
def _create_recognition_config(self):
"""Create the NVIDIA Riva ASR recognition configuration."""

View File

@@ -34,6 +34,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Mapping, Optional, Type,
from loguru import logger
from pipecat.transcriptions.language import Language
if TYPE_CHECKING:
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionConfig
@@ -294,11 +296,13 @@ class TTSSettings(ServiceSettings):
Parameters:
model: TTS model identifier.
voice: Voice identifier or name.
language: Language for speech synthesis.
language: Language for speech synthesis. Accepts a ``Language`` enum
(converted to a service-specific string) or a raw string (stored
as-is).
"""
voice: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
language: Any = field(default_factory=lambda: NOT_GIVEN)
language: Language | str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
_aliases: ClassVar[Dict[str, str]] = {"voice_id": "voice"}
@@ -309,7 +313,9 @@ class STTSettings(ServiceSettings):
Parameters:
model: STT model identifier.
language: Language for speech recognition.
language: Language for speech recognition. Accepts a ``Language`` enum
(converted to a service-specific string) or a raw string (stored
as-is).
"""
language: Any = field(default_factory=lambda: NOT_GIVEN)
language: Language | str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)

View File

@@ -35,7 +35,7 @@ from pipecat.frames.frames import (
from pipecat.metrics.metrics import TTFBMetricsData
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.services.settings import STTSettings
from pipecat.services.settings import STTSettings, is_given
from pipecat.services.stt_latency import DEFAULT_TTFS_P99
from pipecat.services.websocket_service import WebsocketService
from pipecat.transcriptions.language import Language
@@ -206,6 +206,17 @@ class STTService(AIService):
settings_cls = type(self._settings)
await self._update_settings(settings_cls(language=language))
def language_to_service_language(self, language: Language) -> Optional[str]:
    """Convert a language to the service-specific language format.

    Default hook used by ``_update_settings`` to translate a ``Language``
    enum before it is stored. This base version performs no
    service-specific mapping: it returns ``Language(language)``.

    Args:
        language: The language to convert.

    Returns:
        The service-specific language identifier, or None if not supported.
        This base implementation returns the ``Language`` value itself.
    """
    # NOTE(review): this base version never yields None; presumably Language
    # is a str-valued enum so the result satisfies Optional[str] -- confirm.
    return Language(language)
@abstractmethod
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Run speech-to-text on the provided audio data.
@@ -239,8 +250,9 @@ class STTService(AIService):
async def _update_settings(self, update: STTSettings) -> dict[str, Any]:
"""Apply an STT settings update.
Handles ``model`` (via parent). Does **not** call ``set_language``
— concrete services should override this method and handle language
Handles ``model`` (via parent). Translates ``Language`` enum values
before applying so the stored value is a service-specific string.
Concrete services should override this method and handle language
changes (including any reconnect logic) based on the returned
changed-field dict.
@@ -250,6 +262,12 @@ class STTService(AIService):
Returns:
Dict mapping changed field names to their previous values.
"""
# Translate language *before* applying so the stored value is canonical
if is_given(update.language) and isinstance(update.language, Language):
converted = self.language_to_service_language(update.language)
if converted is not None:
update.language = converted
changed = await super()._update_settings(update)
return changed

View File

@@ -448,7 +448,7 @@ class TTSService(AIService):
Dict mapping changed field names to their previous values.
"""
# Translate language *before* applying so the stored value is canonical
if is_given(update.language) and update.language is not None:
if is_given(update.language) and isinstance(update.language, Language):
converted = self.language_to_service_language(update.language)
if converted is not None:
update.language = converted

View File

@@ -183,7 +183,7 @@ class BaseWhisperSTTService(SegmentedSTTService):
changed = await super()._update_settings(update)
if "language" in changed:
self._language = self.language_to_service_language(Language(self._settings.language))
self._language = self._settings.language
if "prompt" in changed:
self._prompt = self._settings.prompt
if "temperature" in changed:

View File

@@ -319,9 +319,8 @@ class WhisperSTTService(SegmentedSTTService):
# Divide by 32768 because we have signed 16-bit data.
audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0
whisper_lang = self.language_to_service_language(self._settings.language)
segments, _ = await asyncio.to_thread(
self._model.transcribe, audio_float, language=whisper_lang
self._model.transcribe, audio_float, language=self._settings.language
)
text: str = ""
for segment in segments:
@@ -419,13 +418,12 @@ class WhisperSTTServiceMLX(WhisperSTTService):
# Divide by 32768 because we have signed 16-bit data.
audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0
whisper_lang = self.language_to_service_language(self._settings.language)
chunk = await asyncio.to_thread(
mlx_whisper.transcribe,
audio_float,
path_or_hf_repo=self.model_name,
temperature=self._temperature,
language=whisper_lang,
language=self._settings.language,
)
text: str = ""
for segment in chunk.get("segments", []):