add sambanova llm and stt
This commit is contained in:
@@ -53,8 +53,8 @@ You can connect to Pipecat from any platform using our official SDKs:
|
||||
|
||||
| Category | Services |
|
||||
| ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Speech-to-Text | [AssemblyAI](https://docs.pipecat.ai/server/services/stt/assemblyai), [AWS](https://docs.pipecat.ai/server/services/stt/aws), [Azure](https://docs.pipecat.ai/server/services/stt/azure), [Cartesia](https://docs.pipecat.ai/server/services/stt/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/stt/deepgram), [Fal Wizper](https://docs.pipecat.ai/server/services/stt/fal), [Gladia](https://docs.pipecat.ai/server/services/stt/gladia), [Google](https://docs.pipecat.ai/server/services/stt/google), [Groq (Whisper)](https://docs.pipecat.ai/server/services/stt/groq), [OpenAI (Whisper)](https://docs.pipecat.ai/server/services/stt/openai), [Parakeet (NVIDIA)](https://docs.pipecat.ai/server/services/stt/parakeet), [SambaNova (Whisper)](https://docs.pipecat.ai/server/services/stt/sambanova) [Ultravox](https://docs.pipecat.ai/server/services/stt/ultravox), [Whisper](https://docs.pipecat.ai/server/services/stt/whisper) |
|
||||
| LLMs | [Anthropic](https://docs.pipecat.ai/server/services/llm/anthropic), [AWS](https://docs.pipecat.ai/server/services/llm/aws), [Azure](https://docs.pipecat.ai/server/services/llm/azure), [Cerebras](https://docs.pipecat.ai/server/services/llm/cerebras), [DeepSeek](https://docs.pipecat.ai/server/services/llm/deepseek), [Fireworks AI](https://docs.pipecat.ai/server/services/llm/fireworks), [Gemini](https://docs.pipecat.ai/server/services/llm/gemini), [Grok](https://docs.pipecat.ai/server/services/llm/grok), [Groq](https://docs.pipecat.ai/server/services/llm/groq), [NVIDIA NIM](https://docs.pipecat.ai/server/services/llm/nim), [Ollama](https://docs.pipecat.ai/server/services/llm/ollama), [OpenAI](https://docs.pipecat.ai/server/services/llm/openai), [OpenRouter](https://docs.pipecat.ai/server/services/llm/openrouter), [Perplexity](https://docs.pipecat.ai/server/services/llm/perplexity), [Qwen](https://docs.pipecat.ai/server/services/llm/qwen), [SambaNova](https://docs.pipecat.ai/server/services/llm/sambanova) [Together AI](https://docs.pipecat.ai/server/services/llm/together) |
|
||||
| Text-to-Speech | [AWS](https://docs.pipecat.ai/server/services/tts/aws), [Azure](https://docs.pipecat.ai/server/services/tts/azure), [Cartesia](https://docs.pipecat.ai/server/services/tts/cartesia), [Deepgram](https://docs.pipecat.ai/server/services/tts/deepgram), [ElevenLabs](https://docs.pipecat.ai/server/services/tts/elevenlabs), [FastPitch (NVIDIA)](https://docs.pipecat.ai/server/services/tts/fastpitch), [Fish](https://docs.pipecat.ai/server/services/tts/fish), [Google](https://docs.pipecat.ai/server/services/tts/google), [LMNT](https://docs.pipecat.ai/server/services/tts/lmnt), [MiniMax](https://docs.pipecat.ai/server/services/tts/minimax), [Neuphonic](https://docs.pipecat.ai/server/services/tts/neuphonic), [OpenAI](https://docs.pipecat.ai/server/services/tts/openai), [Piper](https://docs.pipecat.ai/server/services/tts/piper), [PlayHT](https://docs.pipecat.ai/server/services/tts/playht), [Rime](https://docs.pipecat.ai/server/services/tts/rime), [Sarvam](https://docs.pipecat.ai/server/services/tts/sarvam), [XTTS](https://docs.pipecat.ai/server/services/tts/xtts) |
|
||||
| Speech-to-Speech | [AWS Nova Sonic](https://docs.pipecat.ai/server/services/s2s/aws), [Gemini Multimodal Live](https://docs.pipecat.ai/server/services/s2s/gemini), [OpenAI Realtime](https://docs.pipecat.ai/server/services/s2s/openai) |
|
||||
| Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), [FastAPI Websocket](https://docs.pipecat.ai/server/services/transport/fastapi-websocket), [SmallWebRTCTransport](https://docs.pipecat.ai/server/services/transport/small-webrtc), [WebSocket Server](https://docs.pipecat.ai/server/services/transport/websocket-server), Local |
|
||||
|
||||
@@ -42,6 +42,7 @@ pipecat-ai[openai]
|
||||
pipecat-ai[qwen]
|
||||
pipecat-ai[remote-smart-turn]
|
||||
# pipecat-ai[riva] # Mocked
|
||||
pipecat-ai[sambanova]
|
||||
pipecat-ai[silero]
|
||||
pipecat-ai[simli]
|
||||
pipecat-ai[soundfile]
|
||||
|
||||
@@ -107,4 +107,7 @@ MINIMAX_API_KEY=...
|
||||
MINIMAX_GROUP_ID=...
|
||||
|
||||
# Sarvam AI
|
||||
SARVAM_API_KEY=...
|
||||
SARVAM_API_KEY=...
|
||||
|
||||
# SambaNova
|
||||
SAMBANOVA_API_KEY=...
|
||||
@@ -79,6 +79,7 @@ playht = [ "pyht~=0.1.12", "websockets~=13.1" ]
|
||||
qwen = []
|
||||
rime = [ "websockets~=13.1" ]
|
||||
riva = [ "nvidia-riva-client~=2.19.1" ]
|
||||
sambanova = []
|
||||
sentry = [ "sentry-sdk~=2.23.1" ]
|
||||
local-smart-turn = [ "coremltools>=8.0", "transformers", "torch==2.5.0", "torchaudio==2.5.0" ]
|
||||
remote-smart-turn = []
|
||||
|
||||
14
src/pipecat/services/sambanova/__init__.py
Normal file
14
src/pipecat/services/sambanova/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
from pipecat.services import DeprecatedModuleProxy
|
||||
|
||||
from .llm import *
|
||||
from .stt import *
|
||||
|
||||
sys.modules[__name__] = DeprecatedModuleProxy(globals(), "sambanova", "sambanova.[llm,stt,tts]")
|
||||
168
src/pipecat/services/sambanova/llm.py
Normal file
168
src/pipecat/services/sambanova/llm.py
Normal file
@@ -0,0 +1,168 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
from pipecat.frames.frames import (
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.utils.tracing.service_decorators import traced_llm
|
||||
|
||||
|
||||
class SambaNovaLLMService(OpenAILLMService): # type: ignore
|
||||
"""A service for interacting with SambaNova using the OpenAI-compatible interface.
|
||||
This service extends OpenAILLMService to connect to SambaNova's API endpoint while
|
||||
maintaining full compatibility with OpenAI's interface and functionality.
|
||||
Args:
|
||||
api_key (str): The API key for accessing SambaNova API.
|
||||
model (str, optional): The model identifier to use. Defaults to "Meta-Llama-3.3-70B-Instruct".
|
||||
base_url (str, optional): The base URL for SambaNova API. Defaults to "https://api.sambanova.ai/v1".
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = 'Llama-4-Maverick-17B-128E-Instruct',
|
||||
base_url: str = 'https://api.sambanova.ai/v1',
|
||||
**kwargs: Dict[Any, Any],
|
||||
) -> None:
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
|
||||
def create_client(
|
||||
self, api_key: Optional[str] = None, base_url: Optional[str] = None, **kwargs: Dict[Any, Any]
|
||||
) -> Any:
|
||||
"""Create OpenAI-compatible client for SambaNova API endpoint."""
|
||||
|
||||
logger.debug(f'Creating SambaNova client with API {base_url}')
|
||||
return super().create_client(api_key, base_url, **kwargs)
|
||||
|
||||
async def get_chat_completions(self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]) -> Any:
|
||||
"""Get chat completions from SambaNova API endpoint."""
|
||||
|
||||
params = {
|
||||
'model': self.model_name,
|
||||
'stream': True,
|
||||
'messages': messages,
|
||||
'tools': context.tools,
|
||||
'tool_choice': context.tool_choice,
|
||||
'stream_options': {'include_usage': True},
|
||||
'temperature': self._settings['temperature'],
|
||||
'top_p': self._settings['top_p'],
|
||||
'max_tokens': self._settings['max_tokens'],
|
||||
'max_completion_tokens': self._settings['max_completion_tokens'],
|
||||
}
|
||||
|
||||
params.update(self._settings['extra'])
|
||||
|
||||
chunks = await self._client.chat.completions.create(**params)
|
||||
return chunks
|
||||
|
||||
@traced_llm # type: ignore
|
||||
async def _process_context(self, context: OpenAILLMContext) -> AsyncStream[ChatCompletionChunk]:
|
||||
"""Redefine this method until SambaNova API introduces indexing in tool calls."""
|
||||
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
func_idx = 0
|
||||
function_name = ''
|
||||
arguments = ''
|
||||
tool_call_id = ''
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(context)
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
if not chunk.choices[0].delta:
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
# We're streaming the LLM response to enable the fastest response times.
|
||||
# For text, we just yield each chunk as we receive it and count on consumers
|
||||
# to do whatever coalescing they need (eg. to pass full sentences to TTS)
|
||||
#
|
||||
# If the LLM is a function call, we'll do some coalescing here.
|
||||
# If the response contains a function name, we'll yield a frame to tell consumers
|
||||
# that they can start preparing to call the function with that name.
|
||||
# We accumulate all the arguments for the rest of the streamed response, then when
|
||||
# the response is done, we package up all the arguments and the function name and
|
||||
# yield a frame containing the function name and the arguments.
|
||||
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.index != func_idx:
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
function_name = ''
|
||||
arguments = ''
|
||||
tool_call_id = ''
|
||||
func_idx += 1
|
||||
if tool_call.function and tool_call.function.name:
|
||||
function_name += tool_call.function.name
|
||||
tool_call_id = tool_call.id # type: ignore
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# Keep iterating through the response to collect all the argument fragments
|
||||
arguments += tool_call.function.arguments
|
||||
elif chunk.choices[0].delta.content:
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
# When gpt-4o-audio / gpt-4o-mini-audio is used for llm or stt+llm
|
||||
# we need to get LLMTextFrame for the transcript
|
||||
elif hasattr(chunk.choices[0].delta, 'audio') and chunk.choices[0].delta.audio.get('transcript'):
|
||||
await self.push_frame(LLMTextFrame(chunk.choices[0].delta.audio['transcript']))
|
||||
|
||||
# if we got a function name and arguments, check to see if it's a function with
|
||||
# a registered handler. If so, run the registered callback, save the result to
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
# handler, raise an exception.
|
||||
if function_name and arguments:
|
||||
# added to the list as last function name and arguments not added to the list
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
function_calls = []
|
||||
|
||||
for function_name, arguments, tool_id in zip(functions_list, arguments_list, tool_id_list):
|
||||
# This allows compatibility until SambaNova API introduces indexing in tool calls.
|
||||
if len(arguments) < 1:
|
||||
continue
|
||||
|
||||
arguments = json.loads(arguments)
|
||||
function_calls.append(
|
||||
FunctionCallFromLLM(
|
||||
context=context,
|
||||
tool_call_id=tool_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
|
||||
await self.run_function_calls(function_calls)
|
||||
65
src/pipecat/services/sambanova/stt.py
Normal file
65
src/pipecat/services/sambanova/stt.py
Normal file
@@ -0,0 +1,65 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
|
||||
class SambaNovaSTTService(BaseWhisperSTTService): # type: ignore
|
||||
"""SambaNova Whisper speech-to-text service.
|
||||
Uses SambaNova's Whisper API to convert audio to text.
|
||||
Requires a SambaNova API key set via the api_key parameter or SAMBANOVA_API_KEY environment variable.
|
||||
Args:
|
||||
model: Whisper model to use. Defaults to "Whisper-Large-v3".
|
||||
api_key: SambaNova API key. Defaults to None.
|
||||
base_url: API base URL. Defaults to "https://api.sambanova.ai/v1".
|
||||
language: Language of the audio input. Defaults to English.
|
||||
prompt: Optional text to guide the model's style or continue a previous segment.
|
||||
temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
|
||||
**kwargs: Additional arguments passed to `pipecat.services.whisper.base_stt.BaseWhisperSTTService`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = 'Whisper-Large-v3',
|
||||
api_key: Optional[str] = None,
|
||||
base_url: str = 'https://api.sambanova.ai/v1',
|
||||
language: Optional[Language] = Language.EN,
|
||||
prompt: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _transcribe(self, audio: bytes) -> Transcription:
|
||||
assert self._language is not None # Assigned in the BaseWhisperSTTService class
|
||||
|
||||
# Build kwargs dict with only set parameters
|
||||
kwargs = {
|
||||
'file': ('audio.wav', audio, 'audio/wav'),
|
||||
'model': self.model_name,
|
||||
'response_format': 'json',
|
||||
'language': self._language,
|
||||
}
|
||||
|
||||
if self._prompt is not None:
|
||||
kwargs['prompt'] = self._prompt
|
||||
|
||||
if self._temperature is not None:
|
||||
kwargs['temperature'] = self._temperature
|
||||
|
||||
return await self._client.audio.transcriptions.create(**kwargs)
|
||||
Reference in New Issue
Block a user