From 4e0ece17b6fd4607b69dacc4a1565e6cc6a67e78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Mon, 26 Aug 2024 11:12:51 -0700 Subject: [PATCH] services: added support for setting STT model and language --- src/pipecat/frames/frames.py | 29 ++++++-- src/pipecat/services/ai_services.py | 104 +++++++++++++++++----------- src/pipecat/services/cartesia.py | 19 +++-- 3 files changed, 95 insertions(+), 57 deletions(-) diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index a6b3e22ae..884e1e9eb 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -8,7 +8,7 @@ from typing import Any, List, Mapping, Optional, Tuple from dataclasses import dataclass, field -from pipecat.transcriptions.languages import Language +from pipecat.transcriptions.language import Language from pipecat.utils.utils import obj_count, obj_id from pipecat.vad.vad_analyzer import VADParams @@ -436,6 +436,13 @@ class LLMModelUpdateFrame(ControlFrame): model: str +@dataclass +class TTSModelUpdateFrame(ControlFrame): + """A control frame containing a request to update the TTS model. + """ + model: str + + @dataclass class TTSVoiceUpdateFrame(ControlFrame): """A control frame containing a request to update to a new TTS voice. @@ -445,18 +452,26 @@ class TTSVoiceUpdateFrame(ControlFrame): @dataclass class TTSLanguageUpdateFrame(ControlFrame): - """A control frame containing a request to update to a new TTS language. + """A control frame containing a request to update to a new TTS language and + optional voice. + """ language: Language + voice: str | None = None @dataclass -class TTSLanguageVoicesUpdateFrame(ControlFrame): - """A control frame containing a mapping between a language and the desired - voice for that language. - +class STTModelUpdateFrame(ControlFrame): + """A control frame containing a request to update the STT model. 
""" - voices: Mapping[Language, str] + model: str + + +@dataclass +class STTLanguageUpdateFrame(ControlFrame): + """A control frame containing a request to update to STT language. + """ + language: Language @dataclass diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 8208986ef..50b73e85b 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -18,22 +18,23 @@ from pipecat.frames.frames import ( ErrorFrame, Frame, LLMFullResponseEndFrame, + STTLanguageUpdateFrame, + STTModelUpdateFrame, StartFrame, StartInterruptionFrame, TTSLanguageUpdateFrame, - TTSLanguageVoicesUpdateFrame, + TTSModelUpdateFrame, TTSSpeakFrame, TTSStartedFrame, TTSStoppedFrame, TTSVoiceUpdateFrame, TextFrame, - TranscriptionFrame, UserImageRequestFrame, VisionImageRawFrame ) from pipecat.processors.async_frame_processor import AsyncFrameProcessor from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.transcriptions.languages import Language +from pipecat.transcriptions.language import Language from pipecat.utils.audio import calculate_audio_volume from pipecat.utils.string import match_endofsentence from pipecat.utils.utils import exp_smoothing @@ -177,16 +178,16 @@ class TTSService(AIService): self._stop_frame_queue: asyncio.Queue = asyncio.Queue() self._current_sentence: str = "" + @abstractmethod + async def set_model(self, model: str): + pass + @abstractmethod async def set_voice(self, voice: str): pass @abstractmethod - async def set_language(self, language: Language): - pass - - @abstractmethod - async def set_language_voices(self, voices: Mapping[Language, str]): + async def set_language(self, language: Language, voice: str | None): pass # Converts the text to audio. 
@@ -245,12 +246,12 @@ class TTSService(AIService): await self.push_frame(frame, direction) elif isinstance(frame, TTSSpeakFrame): await self._push_tts_frames(frame.text, False) + elif isinstance(frame, TTSModelUpdateFrame): + await self.set_model(frame.model) elif isinstance(frame, TTSVoiceUpdateFrame): await self.set_voice(frame.voice) elif isinstance(frame, TTSLanguageUpdateFrame): - await self.set_language(frame.language) - elif isinstance(frame, TTSLanguageVoicesUpdateFrame): - await self.set_language_voices(frame.voices) + await self.set_language(frame.language, frame.voice) else: await self.push_frame(frame, direction) @@ -305,6 +306,47 @@ class TTSService(AIService): class STTService(AIService): """STTService is a base class for speech-to-text services.""" + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @abstractmethod + async def set_model(self, model: str): + pass + + @abstractmethod + async def set_language(self, language: Language): + pass + + @abstractmethod + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + """Returns transcript as a string""" + pass + + async def process_audio_frame(self, frame: AudioRawFrame): + await self.process_generator(self.run_stt(frame.audio)) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Processes a frame of audio data, either buffering or transcribing it.""" + await super().process_frame(frame, direction) + + if isinstance(frame, AudioRawFrame): + # In this service we accumulate audio internally and at the end we + # push a TextFrame. We don't really want to push audio frames down. 
+ await self.process_audio_frame(frame) + elif isinstance(frame, STTModelUpdateFrame): + await self.set_model(frame.model) + elif isinstance(frame, STTLanguageUpdateFrame): + await self.set_language(frame.language) + else: + await self.push_frame(frame, direction) + + +class SegmentedSTTService(STTService): + """SegmentedSTTService is an STTService that will detect speech and will run + speech-to-text on speech segments only, instead of a continuous stream. + + """ + def __init__(self, *, min_volume: float = 0.6, @@ -325,24 +367,7 @@ class STTService(AIService): self._smoothing_factor = 0.2 self._prev_volume = 0 - @abstractmethod - async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: - """Returns transcript as a string""" - pass - - def _new_wave(self): - content = io.BytesIO() - ww = wave.open(content, "wb") - ww.setsampwidth(2) - ww.setnchannels(self._num_channels) - ww.setframerate(self._sample_rate) - return (content, ww) - - def _get_smoothed_volume(self, frame: AudioRawFrame) -> float: - volume = calculate_audio_volume(frame.audio, frame.sample_rate) - return exp_smoothing(volume, self._prev_volume, self._smoothing_factor) - - async def _append_audio(self, frame: AudioRawFrame): + async def process_audio_frame(self, frame: AudioRawFrame): # Try to filter out empty background noise volume = self._get_smoothed_volume(frame) if volume >= self._min_volume: @@ -362,9 +387,7 @@ class STTService(AIService): self._silence_num_frames = 0 self._wave.close() self._content.seek(0) - await self.start_processing_metrics() await self.process_generator(self.run_stt(self._content.read())) - await self.stop_processing_metrics() (self._content, self._wave) = self._new_wave() async def stop(self, frame: EndFrame): @@ -373,16 +396,17 @@ class STTService(AIService): async def cancel(self, frame: CancelFrame): self._wave.close() - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Processes a frame of audio data, either buffering or 
transcribing it.""" - await super().process_frame(frame, direction) + def _new_wave(self): + content = io.BytesIO() + ww = wave.open(content, "wb") + ww.setsampwidth(2) + ww.setnchannels(self._num_channels) + ww.setframerate(self._sample_rate) + return (content, ww) - if isinstance(frame, AudioRawFrame): - # In this service we accumulate audio internally and at the end we - # push a TextFrame. We don't really want to push audio frames down. - await self._append_audio(frame) - else: - await self.push_frame(frame, direction) + def _get_smoothed_volume(self, frame: AudioRawFrame) -> float: + volume = calculate_audio_volume(frame.audio, frame.sample_rate) + return exp_smoothing(volume, self._prev_volume, self._smoothing_factor) class ImageGenService(AIService): diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py index 17bf34d15..b267d14c1 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia.py @@ -99,7 +99,6 @@ class CartesiaTTSService(TTSService): "sample_rate": sample_rate, } self._language = language - self._language_voices = {} self._websocket = None self._context_id = None @@ -111,20 +110,20 @@ class CartesiaTTSService(TTSService): def can_generate_metrics(self) -> bool: return True + async def set_model(self, model: str): + logger.debug(f"Switching TTS model to: [{model}]") + self._model_id = model + async def set_voice(self, voice: str): logger.debug(f"Switching TTS voice to: [{voice}]") self._voice_id = voice - async def set_language(self, language: Language): + async def set_language(self, language: Language, voice: str | None): + logger.debug(f"Switching TTS language to: [{language}]") cartesia_language = language_to_cartesia_language(language) - if cartesia_language and language in self._language_voices: - logger.debug(f"Switching TTS language to: [{language}]") - self._language = cartesia_language - await self.set_voice(self._language_voices[language]) - - async def set_language_voices(self, voices: 
Mapping[Language, str]): - logger.debug(f"Setting TTS language voices to: {voices}") - self._language_voices = voices + self._language = cartesia_language + if voice: + self._voice_id = voice async def start(self, frame: StartFrame): await super().start(frame)