From ae6fbb3146fd3090ba8dca8c613777eb2ed14dbf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Mon, 26 Aug 2024 14:30:57 -0700
Subject: [PATCH] services: just set model, voice, language independently

---
Reviewer note: this patch was recovered from a whitespace-collapsed copy.
Line breaks and indentation were reconstructed; line counts were checked
against every hunk header and the diffstat below. Indentation assumes the
usual 4-space Python nesting, and the position of the blank context line
next to the closing docstring quotes in the 2nd and 3rd frames.py hunks is
a best guess. The author's e-mail address was not present in the copy.
TODO: confirm against upstream commit ae6fbb3 before applying.

 src/pipecat/frames/frames.py        |  4 ----
 src/pipecat/services/ai_services.py | 12 ++++++------
 src/pipecat/services/cartesia.py    | 17 +++--------------
 src/pipecat/services/deepgram.py    |  5 +----
 4 files changed, 10 insertions(+), 28 deletions(-)

diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py
index 87f2d599d..13c2f53f1 100644
--- a/src/pipecat/frames/frames.py
+++ b/src/pipecat/frames/frames.py
@@ -441,8 +441,6 @@ class TTSModelUpdateFrame(ControlFrame):
     """A control frame containing a request to update the TTS model.
     """
     model: str
-    voice: str | None = None
-    language: Language | None = None
 
 
 @dataclass
@@ -459,7 +457,6 @@ class TTSLanguageUpdateFrame(ControlFrame):
 
     """
     language: Language
-    voice: str | None = None
 
 
 @dataclass
@@ -469,7 +466,6 @@ class STTModelUpdateFrame(ControlFrame):
 
     """
     model: str
-    language: Language | None = None
 
 
 @dataclass
diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py
index 4f5415586..c10f3e46e 100644
--- a/src/pipecat/services/ai_services.py
+++ b/src/pipecat/services/ai_services.py
@@ -179,7 +179,7 @@ class TTSService(AIService):
         self._current_sentence: str = ""
 
     @abstractmethod
-    async def set_model(self, model: str, voice: str | None, language: Language | None):
+    async def set_model(self, model: str):
         pass
 
     @abstractmethod
@@ -187,7 +187,7 @@ class TTSService(AIService):
         pass
 
     @abstractmethod
-    async def set_language(self, language: Language, voice: str | None):
+    async def set_language(self, language: Language):
         pass
 
     # Converts the text to audio.
@@ -247,11 +247,11 @@ class TTSService(AIService):
         elif isinstance(frame, TTSSpeakFrame):
             await self._push_tts_frames(frame.text, False)
         elif isinstance(frame, TTSModelUpdateFrame):
-            await self.set_model(frame.model, frame.voice, frame.language)
+            await self.set_model(frame.model)
         elif isinstance(frame, TTSVoiceUpdateFrame):
             await self.set_voice(frame.voice)
         elif isinstance(frame, TTSLanguageUpdateFrame):
-            await self.set_language(frame.language, frame.voice)
+            await self.set_language(frame.language)
         else:
             await self.push_frame(frame, direction)
 
@@ -310,7 +310,7 @@ class STTService(AIService):
         super().__init__(**kwargs)
 
     @abstractmethod
-    async def set_model(self, model: str, language: Language | None):
+    async def set_model(self, model: str):
         pass
 
     @abstractmethod
@@ -334,7 +334,7 @@ class STTService(AIService):
             # push a TextFrame. We don't really want to push audio frames down.
             await self.process_audio_frame(frame)
         elif isinstance(frame, STTModelUpdateFrame):
-            await self.set_model(frame.model, frame.language)
+            await self.set_model(frame.model)
         elif isinstance(frame, STTLanguageUpdateFrame):
             await self.set_language(frame.language)
         else:
diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py
index bc7d76dde..e3541ccea 100644
--- a/src/pipecat/services/cartesia.py
+++ b/src/pipecat/services/cartesia.py
@@ -110,28 +110,17 @@ class CartesiaTTSService(TTSService):
     def can_generate_metrics(self) -> bool:
         return True
 
-    async def set_model(self, model: str, voice: str | None, language: Language | None):
+    async def set_model(self, model: str):
         logger.debug(f"Switching TTS model to: [{model}]")
         self._model_id = model
-        if language:
-            logger.debug(f"Switching TTS language to: [{language}]")
-            cartesia_language = language_to_cartesia_language(language)
-            self._language = cartesia_language
-        if voice:
-            logger.debug(f"Switching TTS voice to: [{voice}]")
-            self._voice_id = voice
 
     async def set_voice(self, voice: str):
         logger.debug(f"Switching TTS voice to: [{voice}]")
         self._voice_id = voice
 
-    async def set_language(self, language: Language, voice: str | None):
+    async def set_language(self, language: Language):
         logger.debug(f"Switching TTS language to: [{language}]")
-        cartesia_language = language_to_cartesia_language(language)
-        self._language = cartesia_language
-        if voice:
-            logger.debug(f"Switching TTS voice to: [{voice}]")
-            self._voice_id = voice
+        self._language = language_to_cartesia_language(language)
 
     async def start(self, frame: StartFrame):
         await super().start(frame)
diff --git a/src/pipecat/services/deepgram.py b/src/pipecat/services/deepgram.py
index d8eb5f7c6..d899d4bdb 100644
--- a/src/pipecat/services/deepgram.py
+++ b/src/pipecat/services/deepgram.py
@@ -134,12 +134,9 @@ class DeepgramSTTService(STTService):
         self._connection: AsyncListenWebSocketClient = self._client.listen.asyncwebsocket.v("1")
         self._connection.on(LiveTranscriptionEvents.Transcript, self._on_message)
 
-    async def set_model(self, model: str, language: Language | None):
+    async def set_model(self, model: str):
         logger.debug(f"Switching STT model to: [{model}]")
         self._live_options.model = model
-        if language:
-            logger.debug(f"Switching STT language to: [{language}]")
-            self._live_options.language = language
         await self._disconnect()
         await self._connect()
 