Add GroqSTTService, BaseWhisperSTTService, and refactor OpenAISTTService

This commit is contained in:
Mark Backman
2025-02-08 11:11:33 -05:00
parent 71ce8f9bcf
commit 32b9de5f51
6 changed files with 155 additions and 45 deletions

View File

@@ -7,8 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- Added support for Groq's Whisper API through the new `GroqSTTService` and
OpenAI's Whisper API through the new `OpenAISTTService`. Introduced a new
base class `BaseWhisperSTTService` to handle common Whisper API
functionality.
### Changed
- Updated foundation example `14f-function-calling-groq.py` to use
`GroqSTTService` for transcription.
- `RTVIObserver` doesn't handle `LLMSearchResponseFrame` frames anymore. For
now, to handle those frames you need to create a `GoogleRTVIObserver` instead.

View File

@@ -57,7 +57,7 @@ pip install "pipecat-ai[option,...]"
| Category | Services | Install Command Example |
| ------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` |
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` |
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | `pip install "pipecat-ai[openai]"` |
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | `pip install "pipecat-ai[cartesia]"` |
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | `pip install "pipecat-ai[openai]"` |

View File

@@ -20,7 +20,7 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.groq import GroqLLMService
from pipecat.services.groq import GroqLLMService, GroqSTTService
from pipecat.services.openai import OpenAILLMContext
from pipecat.transports.services.daily import DailyParams, DailyTransport
@@ -50,20 +50,20 @@ async def main():
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
vad_audio_passthrough=True,
),
)
stt = GroqSTTService(api_key=os.getenv("GROQ_API_KEY"), model="distil-whisper-large-v3-en")
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
)
llm = GroqLLMService(
api_key=os.getenv("GROQ_API_KEY"), model="llama3-groq-70b-8192-tool-use-preview"
)
llm = GroqLLMService(api_key=os.getenv("GROQ_API_KEY"), model="llama-3.3-70b-versatile")
# Register a function_name of None to get all functions
# sent to the same callback with an additional function_name parameter.
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
@@ -105,6 +105,7 @@ async def main():
pipeline = Pipeline(
[
transport.input(),
stt,
context_aggregator.user(),
llm,
tts,

View File

@@ -5,9 +5,12 @@
#
from typing import Optional
from loguru import logger
from pipecat.services.openai import OpenAILLMService
from pipecat.services.whisper_base import BaseWhisperSTTService, Transcription
class GroqLLMService(OpenAILLMService):
@@ -37,3 +40,33 @@ class GroqLLMService(OpenAILLMService):
"""Create OpenAI-compatible client for Groq API endpoint."""
logger.debug(f"Creating Groq client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)
class GroqSTTService(BaseWhisperSTTService):
    """Groq Whisper speech-to-text service.

    Uses Groq's Whisper API to convert audio to text. Requires a Groq API key
    set via the api_key parameter or GROQ_API_KEY environment variable.

    Args:
        model: Whisper model to use. Defaults to "whisper-large-v3-turbo".
        api_key: Groq API key. Defaults to None, in which case the value of
            the GROQ_API_KEY environment variable (if any) is used.
        base_url: API base URL. Defaults to "https://api.groq.com/openai/v1".
        **kwargs: Additional arguments passed to BaseWhisperSTTService.
    """

    def __init__(
        self,
        *,
        model: str = "whisper-large-v3-turbo",
        api_key: Optional[str] = None,
        base_url: str = "https://api.groq.com/openai/v1",
        **kwargs,
    ):
        # Imported locally so this class does not depend on a module-level
        # `import os` being present in the file.
        import os

        # The base class hands the key to an OpenAI-compatible client, which
        # on its own only falls back to OPENAI_API_KEY. Resolve the Groq
        # variable here so the documented GROQ_API_KEY fallback actually works.
        # An explicitly supplied key (even an empty string) is passed through
        # untouched.
        if api_key is None:
            api_key = os.getenv("GROQ_API_KEY")

        super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)

    async def _transcribe(self, audio: bytes) -> Transcription:
        """Send one audio segment to Groq's transcription endpoint.

        Args:
            audio: Audio bytes, uploaded as a file named "audio.wav" with the
                "audio/wav" content type.

        Returns:
            Transcription: The API response, requested in "json" format.
        """
        return await self._client.audio.transcriptions.create(
            file=("audio.wav", audio, "audio/wav"),
            model=self.model_name,
            response_format="json",
        )

View File

@@ -30,7 +30,6 @@ from pipecat.frames.frames import (
OpenAILLMContextAssistantTimestampFrame,
StartFrame,
StartInterruptionFrame,
TranscriptionFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
@@ -52,9 +51,9 @@ from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import (
ImageGenService,
LLMService,
SegmentedSTTService,
TTSService,
)
from pipecat.services.whisper_base import BaseWhisperSTTService, Transcription
from pipecat.utils.time import time_now_iso8601
try:
@@ -65,7 +64,6 @@ try:
BadRequestError,
DefaultAsyncHttpxClient,
)
from openai.types.audio import Transcription
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
@@ -398,16 +396,17 @@ class OpenAIImageGenService(ImageGenService):
yield frame
class OpenAISTTService(SegmentedSTTService):
"""OpenAI Speech-to-Text (STT) service.
class OpenAISTTService(BaseWhisperSTTService):
"""OpenAI Whisper speech-to-text service.
This service uses OpenAI's Whisper API to convert audio to text.
Uses OpenAI's Whisper API to convert audio to text. Requires an OpenAI API key
set via the api_key parameter or OPENAI_API_KEY environment variable.
Args:
model: Whisper model to use. Defaults to "whisper-1".
api_key: OpenAI API key. Defaults to None.
base_url: API base URL. Defaults to None.
**kwargs: Additional arguments passed to SegmentedSTTService.
**kwargs: Additional arguments passed to BaseWhisperSTTService.
"""
def __init__(
@@ -418,39 +417,12 @@ class OpenAISTTService(SegmentedSTTService):
base_url: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
self.set_model_name(model)
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
async def set_model(self, model: str):
self.set_model_name(model)
def can_generate_metrics(self) -> bool:
return True
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
try:
await self.start_processing_metrics()
await self.start_ttfb_metrics()
response: Transcription = await self._client.audio.transcriptions.create(
file=("audio.wav", audio, "audio/wav"), model=self.model_name
)
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
text = response.text.strip()
if text:
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(text, "", time_now_iso8601())
else:
logger.warning("Received empty transcription from API")
except Exception as e:
logger.exception(f"Exception during transcription: {e}")
yield ErrorFrame(f"Error during transcription: {str(e)}")
async def _transcribe(self, audio: bytes) -> Transcription:
return await self._client.audio.transcriptions.create(
file=("audio.wav", audio, "audio/wav"), model=self.model_name
)
class OpenAITTSService(TTSService):

View File

@@ -0,0 +1,94 @@
#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import AsyncGenerator, Optional
from loguru import logger
from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
from pipecat.services.ai_services import SegmentedSTTService
from pipecat.utils.time import time_now_iso8601
try:
from openai import AsyncOpenAI
from openai.types.audio import Transcription
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
class BaseWhisperSTTService(SegmentedSTTService):
    """Shared implementation for speech-to-text services exposing the Whisper API.

    Owns the API client, the model name, the metrics start/stop calls around
    each request, and the conversion of the API response (or a failure) into
    frames. Subclasses only need to implement `_transcribe`.

    Args:
        model: Name of the Whisper model to use.
        api_key: Service API key. Defaults to None.
        base_url: Service API base URL. Defaults to None.
        **kwargs: Additional arguments passed to SegmentedSTTService.
    """

    def __init__(
        self,
        *,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.set_model_name(model)
        self._client = self._create_client(api_key, base_url)

    def _create_client(self, api_key: Optional[str], base_url: Optional[str]):
        """Build the async OpenAI-style client used for transcription requests."""
        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        return client

    async def set_model(self, model: str):
        """Switch the model used for subsequent transcription requests."""
        self.set_model_name(model)

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics."""
        return True

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Transcribe one audio segment and yield the resulting frame.

        Args:
            audio: Audio bytes forwarded unchanged to `_transcribe`.

        Yields:
            A TranscriptionFrame when the API returns non-empty text, or an
            ErrorFrame when anything in the request/processing path raises.
            Nothing is yielded for an empty transcription (a warning is logged).
        """
        try:
            await self.start_processing_metrics()
            await self.start_ttfb_metrics()
            result = await self._transcribe(audio)
            await self.stop_ttfb_metrics()
            await self.stop_processing_metrics()

            transcript = result.text.strip()
            if not transcript:
                logger.warning("Received empty transcription from API")
                return

            logger.debug(f"Transcription: [{transcript}]")
            # Second argument is left empty here — presumably the user id;
            # confirm against TranscriptionFrame's definition.
            yield TranscriptionFrame(transcript, "", time_now_iso8601())
        except Exception as e:
            # Failures are surfaced downstream as an ErrorFrame rather than
            # propagated to the caller.
            logger.exception(f"Exception during transcription: {e}")
            yield ErrorFrame(f"Error during transcription: {str(e)}")

    async def _transcribe(self, audio: bytes) -> Transcription:
        """Transcribe audio data to text.

        Args:
            audio: Raw audio data in WAV format.

        Returns:
            Transcription: Object containing the transcribed text.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError