Initial
This commit is contained in:
1
changelog/3978.added.md
Normal file
1
changelog/3978.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `SarvamLLMService` with support for `sarvam-30b`, `sarvam-30b-16k`, `sarvam-105b` and `sarvam-105b-32k`
|
||||
146
examples/foundational/55zzq-update-settings-sarvam-llm.py
Normal file
146
examples/foundational/55zzq-update-settings-sarvam-llm.py
Normal file
@@ -0,0 +1,146 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame, LLMUpdateSettingsFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.sarvam.llm import SarvamLLMService
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _require_env(name: str) -> str:
|
||||
value = os.getenv(name)
|
||||
if not value:
|
||||
raise ValueError(f"Environment variable `{name}` is required.")
|
||||
return value
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info("Starting bot")
|
||||
|
||||
stt = SarvamSTTService(
|
||||
model="saaras:v3",
|
||||
api_key=_require_env("SARVAM_API_KEY"),
|
||||
)
|
||||
|
||||
tts = SarvamTTSService(
|
||||
model="bulbul:v3",
|
||||
api_key=_require_env("SARVAM_API_KEY"),
|
||||
)
|
||||
|
||||
llm = SarvamLLMService(
|
||||
api_key=_require_env("SARVAM_API_KEY"),
|
||||
model="sarvam-30b",
|
||||
)
|
||||
|
||||
messages: list[Any] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a helpful LLM in a WebRTC call. Your goal is to "
|
||||
"demonstrate your capabilities in a succinct way. Your output "
|
||||
"will be spoken aloud, so avoid special characters that can't "
|
||||
"easily be spoken, such as emojis or bullet points. Respond to "
|
||||
"what the user said in a creative and helpful way."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
user_aggregator,
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
assistant_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info("Client connected")
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
await asyncio.sleep(10)
|
||||
logger.info("Updating Sarvam LLM settings: temperature=0.1")
|
||||
await task.queue_frame(LLMUpdateSettingsFrame(delta=OpenAILLMSettings(temperature=0.1)))
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info("Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -4,5 +4,5 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
from .llm import SarvamLLMService
|
||||
from .tts import *
|
||||
|
||||
235
src/pipecat/services/sarvam/llm.py
Normal file
235
src/pipecat/services/sarvam/llm.py
Normal file
@@ -0,0 +1,235 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Sarvam LLM service implementation using OpenAI-compatible interface."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Literal, Mapping, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from openai import NOT_GIVEN, APITimeoutError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.sarvam._sdk import sdk_headers
|
||||
|
||||
__all__ = ["SarvamLLMService"]
|
||||
|
||||
|
||||
class SarvamLLMService(OpenAILLMService):
|
||||
"""Sarvam LLM service using Sarvam's OpenAI-compatible chat completions API.
|
||||
|
||||
This service extends ``OpenAILLMService`` while adding Sarvam-specific behavior:
|
||||
|
||||
- model allow-list validation
|
||||
- request shaping for Sarvam-compatible parameters
|
||||
- Sarvam auth header wiring (``api-subscription-key``)
|
||||
- SDK User-Agent propagation on every API call
|
||||
- raw Sarvam server error passthrough
|
||||
"""
|
||||
|
||||
SUPPORTED_MODELS = frozenset({"sarvam-30b", "sarvam-30b-16k", "sarvam-105b", "sarvam-105b-32k"})
|
||||
TOOL_CALLING_MODELS = frozenset(
|
||||
{"sarvam-30b", "sarvam-30b-16k", "sarvam-105b", "sarvam-105b-32k"}
|
||||
)
|
||||
|
||||
class InputParams(OpenAILLMService.InputParams):
|
||||
"""Configuration parameters for Sarvam LLM service.
|
||||
|
||||
Parameters:
|
||||
frequency_penalty: Penalty for frequent tokens (-2.0 to 2.0).
|
||||
presence_penalty: Penalty for new tokens (-2.0 to 2.0).
|
||||
seed: Random seed for deterministic outputs.
|
||||
temperature: Sampling temperature (0.0 to 2.0).
|
||||
top_k: Top-k sampling parameter (currently ignored by OpenAI client).
|
||||
top_p: Top-p (nucleus) sampling parameter (0.0 to 1.0).
|
||||
max_tokens: Maximum tokens in response.
|
||||
max_completion_tokens: Maximum completion tokens (not sent to Sarvam API).
|
||||
service_tier: Service tier (not sent to Sarvam API).
|
||||
extra: Additional model-specific parameters.
|
||||
wiki_grounding: Sarvam wiki grounding toggle.
|
||||
reasoning_effort: Reasoning effort level (low, medium, high).
|
||||
"""
|
||||
|
||||
wiki_grounding: Optional[bool] = None
|
||||
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "sarvam-30b",
|
||||
base_url: str = "https://api.sarvam.ai/v1",
|
||||
default_headers: Optional[Mapping[str, str]] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Sarvam LLM service.
|
||||
|
||||
Args:
|
||||
api_key: Sarvam API key used for both OpenAI auth and Sarvam subscription header.
|
||||
model: Sarvam model identifier. Supported values: ``sarvam-30b``, ``sarvam-105b``.
|
||||
base_url: Sarvam OpenAI-compatible base URL.
|
||||
default_headers: Additional HTTP headers to include in requests.
|
||||
params: Input parameters for model configuration.
|
||||
**kwargs: Additional keyword arguments passed to ``OpenAILLMService``.
|
||||
"""
|
||||
self._validate_model(model)
|
||||
|
||||
params = (params or SarvamLLMService.InputParams()).model_copy(deep=True)
|
||||
params.extra = self._build_extra_params(params)
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
default_headers=default_headers,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def create_client(
|
||||
self,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
organization=None,
|
||||
project=None,
|
||||
default_headers=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create OpenAI-compatible client for Sarvam API endpoint.
|
||||
|
||||
Ensures Sarvam auth and SDK identification headers are always attached.
|
||||
"""
|
||||
merged_headers = dict(default_headers or {})
|
||||
merged_headers.update(sdk_headers())
|
||||
if api_key:
|
||||
merged_headers["api-subscription-key"] = api_key
|
||||
|
||||
# Keep SDK User-Agent stable even when caller-provided headers include User-Agent.
|
||||
merged_headers["User-Agent"] = sdk_headers()["User-Agent"]
|
||||
|
||||
logger.debug(f"Creating Sarvam client with API {base_url}")
|
||||
return super().create_client(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
organization=organization,
|
||||
project=project,
|
||||
default_headers=merged_headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def build_chat_completion_params(self, params_from_context: OpenAILLMInvocationParams) -> dict:
|
||||
"""Build parameters for Sarvam chat completion request.
|
||||
|
||||
Starts from OpenAI-compatible defaults, then removes unsupported
|
||||
request fields and applies Sarvam-specific options.
|
||||
"""
|
||||
self._validate_tool_parameters(params_from_context)
|
||||
|
||||
params = super().build_chat_completion_params(params_from_context)
|
||||
params.pop("stream_options", None)
|
||||
params.pop("max_completion_tokens", None)
|
||||
params.pop("service_tier", None)
|
||||
|
||||
extra = self._settings.extra if isinstance(self._settings.extra, dict) else {}
|
||||
if "wiki_grounding" in extra and extra["wiki_grounding"] is not None:
|
||||
params["wiki_grounding"] = extra["wiki_grounding"]
|
||||
if "reasoning_effort" in extra and extra["reasoning_effort"] is not None:
|
||||
params["reasoning_effort"] = extra["reasoning_effort"]
|
||||
|
||||
return params
|
||||
|
||||
async def get_chat_completions(self, params_from_context: OpenAILLMInvocationParams):
|
||||
"""Get streaming chat completions with Sarvam raw error passthrough."""
|
||||
try:
|
||||
return await super().get_chat_completions(params_from_context)
|
||||
except (APITimeoutError, asyncio.TimeoutError, httpx.TimeoutException):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(self._format_raw_server_error(e)) from e
|
||||
|
||||
async def run_inference(
|
||||
self, context: LLMContext | OpenAILLMContext, max_tokens: Optional[int] = None
|
||||
) -> Optional[str]:
|
||||
"""Run one-shot inference and preserve Sarvam raw server errors."""
|
||||
try:
|
||||
return await super().run_inference(context, max_tokens=max_tokens)
|
||||
except (APITimeoutError, asyncio.TimeoutError, httpx.TimeoutException):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(self._format_raw_server_error(e)) from e
|
||||
|
||||
def _validate_model(self, model: str):
|
||||
if model not in self.SUPPORTED_MODELS:
|
||||
allowed = ", ".join(sorted(self.SUPPORTED_MODELS))
|
||||
raise ValueError(f"Unsupported Sarvam LLM model '{model}'. Allowed values: {allowed}.")
|
||||
|
||||
def _build_extra_params(self, params: BaseModel) -> dict[str, Any]:
|
||||
extra = dict(getattr(params, "extra", {}) or {})
|
||||
if getattr(params, "wiki_grounding", None) is not None:
|
||||
extra["wiki_grounding"] = params.wiki_grounding
|
||||
if getattr(params, "reasoning_effort", None) is not None:
|
||||
extra["reasoning_effort"] = params.reasoning_effort
|
||||
return extra
|
||||
|
||||
def _validate_tool_parameters(self, params_from_context: OpenAILLMInvocationParams):
|
||||
tools = params_from_context.get("tools", NOT_GIVEN)
|
||||
tool_choice = params_from_context.get("tool_choice", NOT_GIVEN)
|
||||
|
||||
has_tools = (
|
||||
tools is not NOT_GIVEN
|
||||
and tools is not None
|
||||
and (not isinstance(tools, list) or len(tools) > 0)
|
||||
)
|
||||
has_tool_choice = tool_choice is not NOT_GIVEN and tool_choice is not None
|
||||
|
||||
if has_tool_choice and not has_tools:
|
||||
raise ValueError("Sarvam requires non-empty `tools` when `tool_choice` is provided.")
|
||||
|
||||
if has_tools and self._settings.model not in self.TOOL_CALLING_MODELS:
|
||||
allowed = ", ".join(sorted(self.TOOL_CALLING_MODELS))
|
||||
raise ValueError(
|
||||
f"Model '{self._settings.model}' does not support tools. "
|
||||
f"Supported models: {allowed}."
|
||||
)
|
||||
|
||||
def _format_raw_server_error(self, error: Exception) -> str:
|
||||
raw_message = self._extract_raw_server_message(error)
|
||||
return f"Sarvam server error: {raw_message}"
|
||||
|
||||
def _extract_raw_server_message(self, error: Exception) -> str:
|
||||
body = getattr(error, "body", None)
|
||||
if body is not None:
|
||||
return self._payload_to_message(body)
|
||||
|
||||
response = getattr(error, "response", None)
|
||||
if response is not None:
|
||||
try:
|
||||
return self._payload_to_message(response.json())
|
||||
except Exception:
|
||||
text = getattr(response, "text", None)
|
||||
if text:
|
||||
return str(text)
|
||||
|
||||
return str(error)
|
||||
|
||||
def _payload_to_message(self, payload: Any) -> str:
|
||||
if isinstance(payload, dict):
|
||||
error_obj = payload.get("error")
|
||||
if isinstance(error_obj, dict) and isinstance(error_obj.get("message"), str):
|
||||
return error_obj["message"]
|
||||
if isinstance(payload.get("message"), str):
|
||||
return payload["message"]
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
if isinstance(payload, list):
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
return str(payload)
|
||||
Reference in New Issue
Block a user