Type language fields and centralize conversion in STT services.
Change `TTSSettings.language` and `STTSettings.language` from `Any` to `Language | str | _NotGiven`. Add `language_to_service_language` base method and centralized `isinstance`-guarded conversion in `STTService._update_settings` (mirroring TTS). Update the TTS guard from `is not None` to `isinstance(…, Language)` so raw strings pass through unchanged. Remove now-redundant per-service language conversion from `_update_settings` overrides (ElevenLabs, Azure, Fal, Whisper). Add `language_to_service_language` to Azure STT so the centralized conversion picks it up. Fix AWS and NVIDIA STT `__init__` to convert language at construction time, then simplify their runtime accessors to read `_settings.language` directly.
This commit is contained in:
@@ -102,7 +102,7 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
super().__init__(ttfs_p99_latency=ttfs_p99_latency, **kwargs)
|
||||
|
||||
self._settings = AWSTranscribeSTTSettings(
|
||||
language=language,
|
||||
language=self.language_to_service_language(language) or "en-US",
|
||||
sample_rate=sample_rate,
|
||||
media_encoding="linear16",
|
||||
number_of_channels=1,
|
||||
@@ -251,9 +251,9 @@ class AWSTranscribeSTTService(WebsocketSTTService):
|
||||
|
||||
logger.debug("Connecting to AWS Transcribe WebSocket")
|
||||
|
||||
language_code = self.language_to_service_language(Language(self._settings.language))
|
||||
language_code = self._settings.language
|
||||
if not language_code:
|
||||
raise ValueError(f"Unsupported language: {self._settings.language}")
|
||||
raise ValueError(f"Unsupported language: {language_code}")
|
||||
|
||||
# Generate random websocket key
|
||||
websocket_key = "".join(
|
||||
|
||||
@@ -123,6 +123,17 @@ class AzureSTTService(STTService):
|
||||
"""
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a Language enum to Azure service-specific language code.
|
||||
|
||||
Args:
|
||||
language: The language to convert.
|
||||
|
||||
Returns:
|
||||
The Azure-specific language identifier, or None if not supported.
|
||||
"""
|
||||
return language_to_azure_language(language)
|
||||
|
||||
async def _update_settings(self, update: STTSettings) -> dict[str, Any]:
|
||||
"""Apply a settings update.
|
||||
|
||||
@@ -130,13 +141,6 @@ class AzureSTTService(STTService):
|
||||
"""
|
||||
changed = await super()._update_settings(update)
|
||||
|
||||
if "language" in changed:
|
||||
# Convert Language enum to Azure language code for consistency.
|
||||
lang = self._settings.language
|
||||
if isinstance(lang, Language):
|
||||
lang = language_to_azure_language(lang)
|
||||
self._settings.language = lang
|
||||
|
||||
# TODO: someday we could reconnect here to apply updated settings.
|
||||
# Code might look something like the below:
|
||||
# if "language" in changed:
|
||||
|
||||
@@ -34,7 +34,7 @@ from pipecat.frames.frames import (
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.settings import NOT_GIVEN, STTSettings, is_given
|
||||
from pipecat.services.settings import NOT_GIVEN, STTSettings
|
||||
from pipecat.services.stt_latency import ELEVENLABS_REALTIME_TTFS_P99, ELEVENLABS_TTFS_P99
|
||||
from pipecat.services.stt_service import SegmentedSTTService, WebsocketSTTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
@@ -306,12 +306,6 @@ class ElevenLabsSTTService(SegmentedSTTService):
|
||||
Returns:
|
||||
Dict mapping changed field names to their previous values.
|
||||
"""
|
||||
# Convert language to ElevenLabs format before applying
|
||||
if is_given(update.language) and isinstance(update.language, Language):
|
||||
converted = self.language_to_service_language(update.language)
|
||||
if converted is not None:
|
||||
update.language = converted
|
||||
|
||||
changed = await super()._update_settings(update)
|
||||
|
||||
if "model" in changed:
|
||||
@@ -555,12 +549,6 @@ class ElevenLabsRealtimeSTTService(WebsocketSTTService):
|
||||
Returns:
|
||||
Dict mapping changed field names to their previous values.
|
||||
"""
|
||||
# Convert language to ElevenLabs format before applying
|
||||
if is_given(update.language) and isinstance(update.language, Language):
|
||||
converted = language_to_elevenlabs_language(update.language)
|
||||
if converted is not None:
|
||||
update.language = converted
|
||||
|
||||
changed = await super()._update_settings(update)
|
||||
|
||||
if not changed:
|
||||
|
||||
@@ -252,15 +252,8 @@ class FalSTTService(SegmentedSTTService):
|
||||
return language_to_fal_language(language)
|
||||
|
||||
async def _update_settings(self, update: STTSettings) -> dict[str, Any]:
|
||||
"""Apply a settings update, converting language if changed."""
|
||||
"""Apply a settings update."""
|
||||
changed = await super()._update_settings(update)
|
||||
|
||||
if "language" in changed:
|
||||
# Convert the Language enum to a Fal language code.
|
||||
lang = self._settings.language
|
||||
if isinstance(lang, Language):
|
||||
self._settings.language = self.language_to_service_language(lang)
|
||||
|
||||
return changed
|
||||
|
||||
@traced_stt
|
||||
|
||||
@@ -488,7 +488,8 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
self._config = None
|
||||
self._asr_service = None
|
||||
self._settings = NvidiaSegmentedSTTSettings(
|
||||
language=params.language or Language.EN_US,
|
||||
language=self.language_to_service_language(params.language or Language.EN_US)
|
||||
or "en-US",
|
||||
profanity_filter=params.profanity_filter,
|
||||
automatic_punctuation=params.automatic_punctuation,
|
||||
verbatim_transcripts=params.verbatim_transcripts,
|
||||
@@ -523,8 +524,8 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
def _get_language_code(self) -> str:
|
||||
"""Resolve the current language enum to an NVIDIA Riva language code string."""
|
||||
return self.language_to_service_language(self._settings.language) or "en-US"
|
||||
"""Get the current NVIDIA Riva language code string."""
|
||||
return self._settings.language or "en-US"
|
||||
|
||||
def _create_recognition_config(self):
|
||||
"""Create the NVIDIA Riva ASR recognition configuration."""
|
||||
|
||||
@@ -34,6 +34,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Mapping, Optional, Type,
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionConfig
|
||||
|
||||
@@ -294,11 +296,13 @@ class TTSSettings(ServiceSettings):
|
||||
Parameters:
|
||||
model: TTS model identifier.
|
||||
voice: Voice identifier or name.
|
||||
language: Language for speech synthesis.
|
||||
language: Language for speech synthesis. Accepts a ``Language`` enum
|
||||
(converted to a service-specific string) or a raw string (stored
|
||||
as-is).
|
||||
"""
|
||||
|
||||
voice: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
language: Any = field(default_factory=lambda: NOT_GIVEN)
|
||||
language: Language | str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
_aliases: ClassVar[Dict[str, str]] = {"voice_id": "voice"}
|
||||
|
||||
@@ -309,7 +313,9 @@ class STTSettings(ServiceSettings):
|
||||
|
||||
Parameters:
|
||||
model: STT model identifier.
|
||||
language: Language for speech recognition.
|
||||
language: Language for speech recognition. Accepts a ``Language`` enum
|
||||
(converted to a service-specific string) or a raw string (stored
|
||||
as-is).
|
||||
"""
|
||||
|
||||
language: Any = field(default_factory=lambda: NOT_GIVEN)
|
||||
language: Language | str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
|
||||
|
||||
@@ -35,7 +35,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.metrics.metrics import TTFBMetricsData
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_service import AIService
|
||||
from pipecat.services.settings import STTSettings
|
||||
from pipecat.services.settings import STTSettings, is_given
|
||||
from pipecat.services.stt_latency import DEFAULT_TTFS_P99
|
||||
from pipecat.services.websocket_service import WebsocketService
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -206,6 +206,17 @@ class STTService(AIService):
|
||||
settings_cls = type(self._settings)
|
||||
await self._update_settings(settings_cls(language=language))
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert a language to the service-specific language format.
|
||||
|
||||
Args:
|
||||
language: The language to convert.
|
||||
|
||||
Returns:
|
||||
The service-specific language identifier, or None if not supported.
|
||||
"""
|
||||
return Language(language)
|
||||
|
||||
@abstractmethod
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Run speech-to-text on the provided audio data.
|
||||
@@ -239,8 +250,9 @@ class STTService(AIService):
|
||||
async def _update_settings(self, update: STTSettings) -> dict[str, Any]:
|
||||
"""Apply an STT settings update.
|
||||
|
||||
Handles ``model`` (via parent). Does **not** call ``set_language``
|
||||
— concrete services should override this method and handle language
|
||||
Handles ``model`` (via parent). Translates ``Language`` enum values
|
||||
before applying so the stored value is a service-specific string.
|
||||
Concrete services should override this method and handle language
|
||||
changes (including any reconnect logic) based on the returned
|
||||
changed-field dict.
|
||||
|
||||
@@ -250,6 +262,12 @@ class STTService(AIService):
|
||||
Returns:
|
||||
Dict mapping changed field names to their previous values.
|
||||
"""
|
||||
# Translate language *before* applying so the stored value is canonical
|
||||
if is_given(update.language) and isinstance(update.language, Language):
|
||||
converted = self.language_to_service_language(update.language)
|
||||
if converted is not None:
|
||||
update.language = converted
|
||||
|
||||
changed = await super()._update_settings(update)
|
||||
return changed
|
||||
|
||||
|
||||
@@ -448,7 +448,7 @@ class TTSService(AIService):
|
||||
Dict mapping changed field names to their previous values.
|
||||
"""
|
||||
# Translate language *before* applying so the stored value is canonical
|
||||
if is_given(update.language) and update.language is not None:
|
||||
if is_given(update.language) and isinstance(update.language, Language):
|
||||
converted = self.language_to_service_language(update.language)
|
||||
if converted is not None:
|
||||
update.language = converted
|
||||
|
||||
@@ -183,7 +183,7 @@ class BaseWhisperSTTService(SegmentedSTTService):
|
||||
changed = await super()._update_settings(update)
|
||||
|
||||
if "language" in changed:
|
||||
self._language = self.language_to_service_language(Language(self._settings.language))
|
||||
self._language = self._settings.language
|
||||
if "prompt" in changed:
|
||||
self._prompt = self._settings.prompt
|
||||
if "temperature" in changed:
|
||||
|
||||
@@ -319,9 +319,8 @@ class WhisperSTTService(SegmentedSTTService):
|
||||
# Divide by 32768 because we have signed 16-bit data.
|
||||
audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
whisper_lang = self.language_to_service_language(self._settings.language)
|
||||
segments, _ = await asyncio.to_thread(
|
||||
self._model.transcribe, audio_float, language=whisper_lang
|
||||
self._model.transcribe, audio_float, language=self._settings.language
|
||||
)
|
||||
text: str = ""
|
||||
for segment in segments:
|
||||
@@ -419,13 +418,12 @@ class WhisperSTTServiceMLX(WhisperSTTService):
|
||||
# Divide by 32768 because we have signed 16-bit data.
|
||||
audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
whisper_lang = self.language_to_service_language(self._settings.language)
|
||||
chunk = await asyncio.to_thread(
|
||||
mlx_whisper.transcribe,
|
||||
audio_float,
|
||||
path_or_hf_repo=self.model_name,
|
||||
temperature=self._temperature,
|
||||
language=whisper_lang,
|
||||
language=self._settings.language,
|
||||
)
|
||||
text: str = ""
|
||||
for segment in chunk.get("segments", []):
|
||||
|
||||
Reference in New Issue
Block a user