services: added support for setting STT model and language

This commit is contained in:
Aleix Conchillo Flaqué
2024-08-26 11:12:51 -07:00
parent fd3fdacdee
commit 4e0ece17b6
3 changed files with 95 additions and 57 deletions

View File

@@ -8,7 +8,7 @@ from typing import Any, List, Mapping, Optional, Tuple
from dataclasses import dataclass, field
from pipecat.transcriptions.languages import Language
from pipecat.transcriptions.language import Language
from pipecat.utils.utils import obj_count, obj_id
from pipecat.vad.vad_analyzer import VADParams
@@ -436,6 +436,13 @@ class LLMModelUpdateFrame(ControlFrame):
model: str
@dataclass
class TTSModelUpdateFrame(ControlFrame):
    """Control frame carrying a request to change the TTS model."""

    # Name of the TTS model being requested.
    model: str
@dataclass
class TTSVoiceUpdateFrame(ControlFrame):
"""A control frame containing a request to update to a new TTS voice.
@@ -445,18 +452,26 @@ class TTSVoiceUpdateFrame(ControlFrame):
@dataclass
class TTSLanguageUpdateFrame(ControlFrame):
"""A control frame containing a request to update to a new TTS language.
"""A control frame containing a request to update to a new TTS language and
optional voice.
"""
language: Language
voice: str | None = None
@dataclass
class TTSLanguageVoicesUpdateFrame(ControlFrame):
"""A control frame containing a mapping between a language and the desired
voice for that language.
class STTModelUpdateFrame(ControlFrame):
"""A control frame containing a request to update the STT model.
"""
voices: Mapping[Language, str]
model: str
@dataclass
class STTLanguageUpdateFrame(ControlFrame):
    """Control frame carrying a request to change the STT language."""

    # Language being requested for speech-to-text.
    language: Language
@dataclass

View File

@@ -18,22 +18,23 @@ from pipecat.frames.frames import (
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
STTLanguageUpdateFrame,
STTModelUpdateFrame,
StartFrame,
StartInterruptionFrame,
TTSLanguageUpdateFrame,
TTSLanguageVoicesUpdateFrame,
TTSModelUpdateFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSVoiceUpdateFrame,
TextFrame,
TranscriptionFrame,
UserImageRequestFrame,
VisionImageRawFrame
)
from pipecat.processors.async_frame_processor import AsyncFrameProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.transcriptions.languages import Language
from pipecat.transcriptions.language import Language
from pipecat.utils.audio import calculate_audio_volume
from pipecat.utils.string import match_endofsentence
from pipecat.utils.utils import exp_smoothing
@@ -177,16 +178,16 @@ class TTSService(AIService):
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
self._current_sentence: str = ""
@abstractmethod
async def set_model(self, model: str):
    """Switch the service to the given TTS model.

    Abstract hook: this base implementation does nothing and returns None.
    """
@abstractmethod
async def set_voice(self, voice: str):
    """Switch the service to the given TTS voice.

    Abstract hook: this base implementation does nothing and returns None.
    """
@abstractmethod
async def set_language(self, language: Language):
pass
@abstractmethod
async def set_language_voices(self, voices: Mapping[Language, str]):
async def set_language(self, language: Language, voice: str | None):
pass
# Converts the text to audio.
@@ -245,12 +246,12 @@ class TTSService(AIService):
await self.push_frame(frame, direction)
elif isinstance(frame, TTSSpeakFrame):
await self._push_tts_frames(frame.text, False)
elif isinstance(frame, TTSModelUpdateFrame):
await self.set_model(frame.model)
elif isinstance(frame, TTSVoiceUpdateFrame):
await self.set_voice(frame.voice)
elif isinstance(frame, TTSLanguageUpdateFrame):
await self.set_language(frame.language)
elif isinstance(frame, TTSLanguageVoicesUpdateFrame):
await self.set_language_voices(frame.voices)
await self.set_language(frame.language, frame.voice)
else:
await self.push_frame(frame, direction)
@@ -305,6 +306,47 @@ class TTSService(AIService):
class STTService(AIService):
    """Base class for speech-to-text services.

    Subclasses provide ``set_model``, ``set_language`` and ``run_stt``.
    ``process_frame`` routes audio frames to ``process_audio_frame`` (they are
    not pushed any further), applies STT model/language update frames through
    the matching setter, and pushes every other frame on unchanged.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    async def set_model(self, model: str):
        """Switch the service to the given STT model."""

    @abstractmethod
    async def set_language(self, language: Language):
        """Switch the service to the given STT language."""

    @abstractmethod
    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Run speech-to-text on ``audio`` and yield the resulting frames."""

    async def process_audio_frame(self, frame: AudioRawFrame):
        """Run ``run_stt`` on the frame's audio and process its output."""
        stt_output = self.run_stt(frame.audio)
        await self.process_generator(stt_output)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Dispatch an incoming frame.

        Audio frames are consumed by ``process_audio_frame``; STT model and
        language update frames are applied to the service; any other frame is
        pushed in the direction it arrived with.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, AudioRawFrame):
            # Audio is consumed here; the raw audio frame itself is not
            # pushed down.
            await self.process_audio_frame(frame)
            return

        if isinstance(frame, STTModelUpdateFrame):
            await self.set_model(frame.model)
            return

        if isinstance(frame, STTLanguageUpdateFrame):
            await self.set_language(frame.language)
            return

        await self.push_frame(frame, direction)
class SegmentedSTTService(STTService):
"""SegmentedSTTService is an STTService that will detect speech and will run
speech-to-text on speech segments only, instead of a continuous stream.
"""
def __init__(self,
*,
min_volume: float = 0.6,
@@ -325,24 +367,7 @@ class STTService(AIService):
self._smoothing_factor = 0.2
self._prev_volume = 0
@abstractmethod
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
    """Run speech-to-text on ``audio`` and yield the resulting frames.

    Abstract hook: this base implementation produces nothing.
    """
def _new_wave(self):
    """Create an empty in-memory WAV buffer and a writer bound to it.

    The writer is configured with a sample width of 2 bytes and with the
    channel count and sample rate stored on the instance.

    Returns:
        A ``(buffer, writer)`` tuple: the ``io.BytesIO`` that receives the
        WAV data and the ``wave`` writer opened on it.
    """
    buffer = io.BytesIO()
    writer = wave.open(buffer, "wb")
    writer.setsampwidth(2)
    writer.setnchannels(self._num_channels)
    writer.setframerate(self._sample_rate)
    return buffer, writer
def _get_smoothed_volume(self, frame: AudioRawFrame) -> float:
    """Return the frame's volume after smoothing.

    The volume computed from the frame's audio and sample rate is passed to
    ``exp_smoothing`` together with the stored previous volume and smoothing
    factor. This method does not update ``self._prev_volume`` itself.
    """
    raw_volume = calculate_audio_volume(frame.audio, frame.sample_rate)
    smoothed = exp_smoothing(raw_volume, self._prev_volume, self._smoothing_factor)
    return smoothed
async def _append_audio(self, frame: AudioRawFrame):
async def process_audio_frame(self, frame: AudioRawFrame):
# Try to filter out empty background noise
volume = self._get_smoothed_volume(frame)
if volume >= self._min_volume:
@@ -362,9 +387,7 @@ class STTService(AIService):
self._silence_num_frames = 0
self._wave.close()
self._content.seek(0)
await self.start_processing_metrics()
await self.process_generator(self.run_stt(self._content.read()))
await self.stop_processing_metrics()
(self._content, self._wave) = self._new_wave()
async def stop(self, frame: EndFrame):
@@ -373,16 +396,17 @@ class STTService(AIService):
async def cancel(self, frame: CancelFrame):
    """Handle a cancel frame by closing the current wave writer."""
    self._wave.close()
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Processes a frame of audio data, either buffering or transcribing it."""
await super().process_frame(frame, direction)
def _new_wave(self):
    """Create an empty in-memory WAV buffer and a writer bound to it.

    The writer is configured with a sample width of 2 bytes and with the
    channel count and sample rate stored on the instance.

    Returns:
        A ``(buffer, writer)`` tuple: the ``io.BytesIO`` that receives the
        WAV data and the ``wave`` writer opened on it.
    """
    buffer = io.BytesIO()
    writer = wave.open(buffer, "wb")
    writer.setsampwidth(2)
    writer.setnchannels(self._num_channels)
    writer.setframerate(self._sample_rate)
    return buffer, writer
if isinstance(frame, AudioRawFrame):
# In this service we accumulate audio internally and at the end we
# push a TextFrame. We don't really want to push audio frames down.
await self._append_audio(frame)
else:
await self.push_frame(frame, direction)
def _get_smoothed_volume(self, frame: AudioRawFrame) -> float:
    """Return the frame's volume after smoothing.

    The volume computed from the frame's audio and sample rate is passed to
    ``exp_smoothing`` together with the stored previous volume and smoothing
    factor. This method does not update ``self._prev_volume`` itself.
    """
    raw_volume = calculate_audio_volume(frame.audio, frame.sample_rate)
    smoothed = exp_smoothing(raw_volume, self._prev_volume, self._smoothing_factor)
    return smoothed
class ImageGenService(AIService):

View File

@@ -99,7 +99,6 @@ class CartesiaTTSService(TTSService):
"sample_rate": sample_rate,
}
self._language = language
self._language_voices = {}
self._websocket = None
self._context_id = None
@@ -111,20 +110,20 @@ class CartesiaTTSService(TTSService):
def can_generate_metrics(self) -> bool:
    """Report that this service can generate metrics (always True)."""
    return True
async def set_model(self, model: str):
    """Log the change and store ``model`` as the service's model id."""
    logger.debug(f"Switching TTS model to: [{model}]")
    self._model_id = model
async def set_voice(self, voice: str):
    """Log the change and store ``voice`` as the service's voice id."""
    logger.debug(f"Switching TTS voice to: [{voice}]")
    self._voice_id = voice
async def set_language(self, language: Language):
async def set_language(self, language: Language, voice: str | None):
logger.debug(f"Switching TTS language to: [{language}]")
cartesia_language = language_to_cartesia_language(language)
if cartesia_language and language in self._language_voices:
logger.debug(f"Switching TTS language to: [{language}]")
self._language = cartesia_language
await self.set_voice(self._language_voices[language])
async def set_language_voices(self, voices: Mapping[Language, str]):
logger.debug(f"Setting TTS language voices to: {voices}")
self._language_voices = voices
self._language = cartesia_language
if voice:
self._voice_id = voice
async def start(self, frame: StartFrame):
await super().start(frame)