Add Whisper STT service using the OpenAI API
This commit is contained in:
@@ -29,6 +29,7 @@ from pipecat.frames.frames import (
|
||||
LLMUpdateSettingsFrame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -47,7 +48,12 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import ImageGenService, LLMService, TTSService
|
||||
from pipecat.services.ai_services import (
|
||||
ImageGenService,
|
||||
LLMService,
|
||||
SegmentedSTTService,
|
||||
TTSService,
|
||||
)
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
@@ -58,6 +64,7 @@ try:
|
||||
BadRequestError,
|
||||
DefaultAsyncHttpxClient,
|
||||
)
|
||||
from openai.types.audio import Transcription
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -390,6 +397,61 @@ class OpenAIImageGenService(ImageGenService):
|
||||
yield frame
|
||||
|
||||
|
||||
class OpenAISTTService(SegmentedSTTService):
    """Speech-to-text service backed by OpenAI's Whisper transcription API.

    Each audio segment handed to ``run_stt`` is uploaded through an
    ``AsyncOpenAI`` client and the returned text is emitted as a
    ``TranscriptionFrame``.

    Args:
        model: Name of the Whisper model to request. Defaults to "whisper-1".
        api_key: OpenAI API key forwarded to ``AsyncOpenAI``. Defaults to None.
        base_url: API base URL forwarded to ``AsyncOpenAI``. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to SegmentedSTTService.
    """

    def __init__(
        self,
        *,
        model: str = "whisper-1",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs,
    ):
        # Base-class setup runs first so the model name can be registered on it.
        super().__init__(**kwargs)
        self.set_model_name(model)
        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    async def set_model(self, model: str):
        """Switch the model name used for subsequent transcription requests."""
        self.set_model_name(model)

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics (always True)."""
        return True

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Transcribe one audio segment and yield the resulting frame.

        The bytes are uploaded labelled as ``audio.wav`` with the
        ``audio/wav`` content type. Processing and TTFB metrics are started
        before the request and stopped as soon as the response arrives.

        Args:
            audio: Raw bytes of the audio segment to transcribe.

        Yields:
            A ``TranscriptionFrame`` carrying the whitespace-stripped text when
            it is non-empty (an empty transcript only logs a warning), or an
            ``ErrorFrame`` describing any exception raised along the way.
        """
        try:
            await self.start_processing_metrics()
            await self.start_ttfb_metrics()

            # (filename, content, MIME type) triple for the multipart upload.
            wav_upload = ("audio.wav", audio, "audio/wav")
            result: Transcription = await self._client.audio.transcriptions.create(
                file=wav_upload,
                model=self.model_name,
            )

            await self.stop_ttfb_metrics()
            await self.stop_processing_metrics()

            transcript = result.text.strip()
            if not transcript:
                logger.warning("Received empty transcription from API")
                return

            logger.debug(f"Transcription: [{transcript}]")
            # The yield stays inside the try so errors raised at this point
            # are still reported through the handler below.
            yield TranscriptionFrame(transcript, "", time_now_iso8601())
        except Exception as exc:
            logger.exception(f"Exception during transcription: {exc}")
            yield ErrorFrame(f"Error during transcription: {str(exc)}")
|
||||
|
||||
|
||||
class OpenAITTSService(TTSService):
|
||||
"""OpenAI Text-to-Speech service that generates audio from text.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user