From 3cb189eb1fe37229e4697e665a6185d44004b511 Mon Sep 17 00:00:00 2001 From: Jin Kim Date: Tue, 4 Feb 2025 10:27:28 +0900 Subject: [PATCH] Add whisper STT service using OpenAI API --- src/pipecat/services/openai.py | 64 +++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index ab81a1abb..2e138cd99 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -29,6 +29,7 @@ from pipecat.frames.frames import ( LLMUpdateSettingsFrame, OpenAILLMContextAssistantTimestampFrame, StartInterruptionFrame, + TranscriptionFrame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame, @@ -47,7 +48,12 @@ from pipecat.processors.aggregators.openai_llm_context import ( OpenAILLMContextFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import ImageGenService, LLMService, TTSService +from pipecat.services.ai_services import ( + ImageGenService, + LLMService, + SegmentedSTTService, + TTSService, +) from pipecat.utils.time import time_now_iso8601 try: @@ -58,6 +64,7 @@ try: BadRequestError, DefaultAsyncHttpxClient, ) + from openai.types.audio import Transcription from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam except ModuleNotFoundError as e: logger.error(f"Exception: {e}") @@ -390,6 +397,61 @@ class OpenAIImageGenService(ImageGenService): yield frame +class OpenAISTTService(SegmentedSTTService): + """OpenAI Speech-to-Text (STT) service. + + This service uses OpenAI's Whisper API to convert audio to text. + + Args: + model: Whisper model to use. Defaults to "whisper-1". + api_key: OpenAI API key. Defaults to None. + base_url: API base URL. Defaults to None. + **kwargs: Additional arguments passed to SegmentedSTTService. + """ + + def __init__( + self, + *, + model: str = "whisper-1", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.set_model_name(model) + self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) + + async def set_model(self, model: str): + self.set_model_name(model) + + def can_generate_metrics(self) -> bool: + return True + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + try: + await self.start_processing_metrics() + await self.start_ttfb_metrics() + + response: Transcription = await self._client.audio.transcriptions.create( + file=("audio.wav", audio, "audio/wav"), model=self.model_name + ) + + await self.stop_ttfb_metrics() + await self.stop_processing_metrics() + + text = response.text.strip() + + if text: + logger.debug(f"Transcription: [{text}]") + yield TranscriptionFrame(text, "", time_now_iso8601()) + else: + logger.warning("Received empty transcription from API") + + except Exception as e: + logger.exception(f"Exception during transcription: {e}") + yield ErrorFrame(f"Error during transcription: {str(e)}") + + class OpenAITTSService(TTSService): """OpenAI Text-to-Speech service that generates audio from text.