Add GroqSTTService, BaseWhisperSTTService, and refactor OpenAISTTService
This commit is contained in:
10
CHANGELOG.md
10
CHANGELOG.md
@@ -7,8 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for Groq's Whisper API through the new `GroqSTTService` and
|
||||
OpenAI's Whisper API through the new `OpenAISTTService`. Introduced a new
|
||||
base class `BaseWhisperSTTService` to handle common Whisper API
|
||||
functionality.
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated foundation example `14f-function-calling-groq.py` to use
|
||||
`GroqSTTService` for transcription.
|
||||
|
||||
- `RTVIObserver` doesn't handle `LLMSearchResponseFrame` frames anymore. For
|
||||
now, to handle those frames you need to create a `GoogleRTVIObserver` instead.
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ pip install "pipecat-ai[option,...]"
|
||||
|
||||
| Category | Services | Install Command Example |
|
||||
| ------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) | `pip install "pipecat-ai[deepgram]"` |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Together AI](https://docs.pipecat.ai/server/services/llm/together) | `pip install "pipecat-ai[openai]"` |
|
||||
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) | `pip install "pipecat-ai[cartesia]"` |
|
||||
| Speech-to-Speech | [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) | `pip install "pipecat-ai[openai]"` |
|
||||
|
||||
@@ -20,7 +20,7 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.groq import GroqLLMService
|
||||
from pipecat.services.groq import GroqLLMService, GroqSTTService
|
||||
from pipecat.services.openai import OpenAILLMContext
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
@@ -50,20 +50,20 @@ async def main():
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
vad_audio_passthrough=True,
|
||||
),
|
||||
)
|
||||
|
||||
stt = GroqSTTService(api_key=os.getenv("GROQ_API_KEY"), model="distil-whisper-large-v3-en")
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = GroqLLMService(
|
||||
api_key=os.getenv("GROQ_API_KEY"), model="llama3-groq-70b-8192-tool-use-preview"
|
||||
)
|
||||
llm = GroqLLMService(api_key=os.getenv("GROQ_API_KEY"), model="llama-3.3-70b-versatile")
|
||||
# Register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
|
||||
@@ -105,6 +105,7 @@ async def main():
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
|
||||
@@ -5,9 +5,12 @@
|
||||
#
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.whisper_base import BaseWhisperSTTService, Transcription
|
||||
|
||||
|
||||
class GroqLLMService(OpenAILLMService):
|
||||
@@ -37,3 +40,33 @@ class GroqLLMService(OpenAILLMService):
|
||||
"""Create OpenAI-compatible client for Groq API endpoint."""
|
||||
logger.debug(f"Creating Groq client with api {base_url}")
|
||||
return super().create_client(api_key, base_url, **kwargs)
|
||||
|
||||
|
||||
class GroqSTTService(BaseWhisperSTTService):
|
||||
"""Groq Whisper speech-to-text service.
|
||||
|
||||
Uses Groq's Whisper API to convert audio to text. Requires a Groq API key
|
||||
set via the api_key parameter or GROQ_API_KEY environment variable.
|
||||
|
||||
Args:
|
||||
model: Whisper model to use. Defaults to "whisper-large-v3-turbo".
|
||||
api_key: Groq API key. Defaults to None.
|
||||
base_url: API base URL. Defaults to "https://api.groq.com/openai/v1".
|
||||
**kwargs: Additional arguments passed to BaseWhisperSTTService.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "whisper-large-v3-turbo",
|
||||
api_key: Optional[str] = None,
|
||||
base_url: str = "https://api.groq.com/openai/v1",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
|
||||
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
return await self._client.audio.transcriptions.create(
|
||||
file=("audio.wav", audio, "audio/wav"), model=self.model_name, response_format="json"
|
||||
)
|
||||
|
||||
@@ -30,7 +30,6 @@ from pipecat.frames.frames import (
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -52,9 +51,9 @@ from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import (
|
||||
ImageGenService,
|
||||
LLMService,
|
||||
SegmentedSTTService,
|
||||
TTSService,
|
||||
)
|
||||
from pipecat.services.whisper_base import BaseWhisperSTTService, Transcription
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
@@ -65,7 +64,6 @@ try:
|
||||
BadRequestError,
|
||||
DefaultAsyncHttpxClient,
|
||||
)
|
||||
from openai.types.audio import Transcription
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -398,16 +396,17 @@ class OpenAIImageGenService(ImageGenService):
|
||||
yield frame
|
||||
|
||||
|
||||
class OpenAISTTService(SegmentedSTTService):
|
||||
"""OpenAI Speech-to-Text (STT) service.
|
||||
class OpenAISTTService(BaseWhisperSTTService):
|
||||
"""OpenAI Whisper speech-to-text service.
|
||||
|
||||
This service uses OpenAI's Whisper API to convert audio to text.
|
||||
Uses OpenAI's Whisper API to convert audio to text. Requires an OpenAI API key
|
||||
set via the api_key parameter or OPENAI_API_KEY environment variable.
|
||||
|
||||
Args:
|
||||
model: Whisper model to use. Defaults to "whisper-1".
|
||||
api_key: OpenAI API key. Defaults to None.
|
||||
base_url: API base URL. Defaults to None.
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService.
|
||||
**kwargs: Additional arguments passed to BaseWhisperSTTService.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -418,39 +417,12 @@ class OpenAISTTService(SegmentedSTTService):
|
||||
base_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.set_model_name(model)
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
self.set_model_name(model)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
response: Transcription = await self._client.audio.transcriptions.create(
|
||||
file=("audio.wav", audio, "audio/wav"), model=self.model_name
|
||||
)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
text = response.text.strip()
|
||||
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(text, "", time_now_iso8601())
|
||||
else:
|
||||
logger.warning("Received empty transcription from API")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception during transcription: {e}")
|
||||
yield ErrorFrame(f"Error during transcription: {str(e)}")
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
return await self._client.audio.transcriptions.create(
|
||||
file=("audio.wav", audio, "audio/wav"), model=self.model_name
|
||||
)
|
||||
|
||||
|
||||
class OpenAITTSService(TTSService):
|
||||
|
||||
94
src/pipecat/services/whisper_base.py
Normal file
94
src/pipecat/services/whisper_base.py
Normal file
@@ -0,0 +1,94 @@
|
||||
#
|
||||
# Copyright (c) 2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
|
||||
from pipecat.services.ai_services import SegmentedSTTService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.audio import Transcription
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class BaseWhisperSTTService(SegmentedSTTService):
|
||||
"""Base class for Whisper-based speech-to-text services.
|
||||
|
||||
Provides common functionality for services implementing the Whisper API interface,
|
||||
including metrics generation and error handling.
|
||||
|
||||
Args:
|
||||
model: Name of the Whisper model to use.
|
||||
api_key: Service API key. Defaults to None.
|
||||
base_url: Service API base URL. Defaults to None.
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.set_model_name(model)
|
||||
self._client = self._create_client(api_key, base_url)
|
||||
|
||||
def _create_client(self, api_key: Optional[str], base_url: Optional[str]):
|
||||
return AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
self.set_model_name(model)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
response = await self._transcribe(audio)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
text = response.text.strip()
|
||||
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(text, "", time_now_iso8601())
|
||||
else:
|
||||
logger.warning("Received empty transcription from API")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception during transcription: {e}")
|
||||
yield ErrorFrame(f"Error during transcription: {str(e)}")
|
||||
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
"""Transcribe audio data to text.
|
||||
|
||||
Args:
|
||||
audio: Raw audio data in WAV format.
|
||||
|
||||
Returns:
|
||||
Transcription: Object containing the transcribed text.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Must be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
Reference in New Issue
Block a user