Add STTService docstrings

This commit is contained in:
Mark Backman
2025-06-25 16:24:44 -04:00
parent f622b281d0
commit ab1d2dbe6a

View File

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Base classes for Speech-to-Text services with continuous and segmented processing."""
import io
import wave
from abc import abstractmethod
@@ -26,7 +28,19 @@ from pipecat.transcriptions.language import Language
class STTService(AIService):
"""STTService is a base class for speech-to-text services."""
"""Base class for speech-to-text services.
Provides common functionality for STT services including audio passthrough,
muting, settings management, and audio processing. Subclasses must implement
the run_stt method to provide actual speech recognition.
Args:
audio_passthrough: Whether to pass audio frames downstream after processing.
Defaults to True.
sample_rate: The sample rate for audio input. If None, will be determined
from the start frame.
**kwargs: Additional arguments passed to the parent AIService.
"""
def __init__(
self,
@@ -44,25 +58,59 @@ class STTService(AIService):
@property
def is_muted(self) -> bool:
    """Check if the STT service is currently muted.

    Returns:
        True if the service is muted and will not process audio.
    """
    # Only one string literal may lead the body: Python takes the first
    # one as ``__doc__`` and silently discards any that follow.
    return self._muted
@property
def sample_rate(self) -> int:
    """Expose the sample rate the service currently uses for audio.

    Returns:
        The sample rate in Hz, as stored on the instance.
    """
    current_rate = self._sample_rate
    return current_rate
async def set_model(self, model: str):
    """Select the model used for speech recognition.

    Forwards the name to ``set_model_name`` on this instance.

    Args:
        model: Name of the speech recognition model to switch to.
    """
    apply_model_name = self.set_model_name
    apply_model_name(model)
async def set_language(self, language: Language):
    """Select the language used for speech recognition.

    This implementation is a no-op: it accepts the argument and does
    nothing with it.

    Args:
        language: The language to use for speech recognition.
    """
    return None
@abstractmethod
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
    """Run speech-to-text on the provided audio data.

    This method must be implemented by subclasses to provide actual speech
    recognition functionality. The body here is only a placeholder and
    produces nothing.

    Args:
        audio: Raw audio bytes to transcribe.

    Yields:
        Frame: Frames containing transcription results.
    """
    # Keep a single leading string literal: the earlier one-liner claimed a
    # string return value, which contradicts the annotated return type, and
    # as the first literal it was what ``__doc__`` actually exposed.
    pass
async def start(self, frame: StartFrame):
"""Start the STT service.
Args:
frame: The start frame containing initialization parameters.
"""
await super().start(frame)
self._sample_rate = self._init_sample_rate or frame.audio_in_sample_rate
@@ -80,13 +128,24 @@ class STTService(AIService):
logger.warning(f"Unknown setting for STT service: {key}")
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
    """Feed one audio frame to the recognizer unless the service is muted.

    While muted the frame is ignored. Otherwise the frame's audio is passed
    to ``run_stt`` and the result is handed to ``process_generator``.

    Args:
        frame: The audio frame to process.
        direction: The direction of frame processing (not used here).
    """
    if not self._muted:
        transcription = self.run_stt(frame.audio)
        await self.process_generator(transcription)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Processes a frame of audio data, either buffering or transcribing it."""
"""Process frames, handling VAD events and audio segmentation.
Args:
frame: The frame to process.
direction: The direction of frame processing.
"""
await super().process_frame(frame, direction)
if isinstance(frame, AudioRawFrame):
@@ -106,14 +165,19 @@ class STTService(AIService):
class SegmentedSTTService(STTService):
"""SegmentedSTTService is an STTService that uses VAD events to detect
speech and will run speech-to-text on speech segments only, instead of a
continuous stream. Since it uses VAD it means that VAD needs to be enabled in
the pipeline.
"""STT service that processes speech in segments using VAD events.
This service always keeps a small audio buffer to take into account that VAD
events are delayed from when the user speech really starts.
Uses Voice Activity Detection (VAD) events to detect speech segments and runs
speech-to-text only on those segments, rather than continuously.
Requires VAD to be enabled in the pipeline to function properly. Maintains a
small audio buffer to account for the delay between actual speech start and
VAD detection.
Args:
sample_rate: The sample rate for audio input. If None, will be determined
from the start frame.
**kwargs: Additional arguments passed to the parent STTService.
"""
def __init__(self, *, sample_rate: Optional[int] = None, **kwargs):
@@ -125,10 +189,16 @@ class SegmentedSTTService(STTService):
self._user_speaking = False
async def start(self, frame: StartFrame):
    """Start the segmented STT service and size its audio buffer.

    Runs the parent start logic first, then derives the buffer size from
    the service's sample rate.

    Args:
        frame: The start frame containing initialization parameters.
    """
    await super().start(frame)
    # NOTE(review): the "* 2" presumably means 2 bytes per sample (16-bit
    # mono PCM), making this one second of audio in bytes — confirm.
    one_second_size = self.sample_rate * 2
    self._audio_buffer_size_1s = one_second_size
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames, handling VAD events and audio segmentation."""
await super().process_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
@@ -162,6 +232,15 @@ class SegmentedSTTService(STTService):
self._audio_buffer.clear()
async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection):
"""Process audio frames by buffering them for segmented transcription.
Continuously buffers audio, growing the buffer while user is speaking and
maintaining a small buffer when not speaking to account for VAD delay.
Args:
frame: The audio frame to process.
direction: The direction of frame processing.
"""
# If the user is speaking the audio buffer will keep growing.
self._audio_buffer += frame.audio