From dc0386937a703131ffc9027902fa808d4c20ed3f Mon Sep 17 00:00:00 2001 From: dhruvladia-sarvam Date: Wed, 11 Mar 2026 02:27:57 +0530 Subject: [PATCH] Initial --- changelog/3978.added.md | 1 + .../55zzq-update-settings-sarvam-llm.py | 146 +++++++++++ src/pipecat/services/sarvam/__init__.py | 2 +- src/pipecat/services/sarvam/llm.py | 235 ++++++++++++++++++ 4 files changed, 383 insertions(+), 1 deletion(-) create mode 100644 changelog/3978.added.md create mode 100644 examples/foundational/55zzq-update-settings-sarvam-llm.py create mode 100644 src/pipecat/services/sarvam/llm.py diff --git a/changelog/3978.added.md b/changelog/3978.added.md new file mode 100644 index 000000000..bb80b10c0 --- /dev/null +++ b/changelog/3978.added.md @@ -0,0 +1 @@ +- Added `SarvamLLMService` with support for `sarvam-30b`, `sarvam-30b-16k`, `sarvam-105b` and `sarvam-105b-32k` diff --git a/examples/foundational/55zzq-update-settings-sarvam-llm.py b/examples/foundational/55zzq-update-settings-sarvam-llm.py new file mode 100644 index 000000000..1d3f9d754 --- /dev/null +++ b/examples/foundational/55zzq-update-settings-sarvam-llm.py @@ -0,0 +1,146 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +from typing import Any + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import LLMRunFrame, LLMUpdateSettingsFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, + LLMUserAggregatorParams, +) +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.openai.base_llm import OpenAILLMSettings +from pipecat.services.sarvam.llm import SarvamLLMService +from pipecat.services.sarvam.stt import SarvamSTTService +from pipecat.services.sarvam.tts import SarvamTTSService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams + +load_dotenv(override=True) + +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + ), +} + + +def _require_env(name: str) -> str: + value = os.getenv(name) + if not value: + raise ValueError(f"Environment variable `{name}` is required.") + return value + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info("Starting bot") + + stt = SarvamSTTService( + model="saaras:v3", + api_key=_require_env("SARVAM_API_KEY"), + ) + + tts = SarvamTTSService( + model="bulbul:v3", + api_key=_require_env("SARVAM_API_KEY"), + ) + + llm = SarvamLLMService( + api_key=_require_env("SARVAM_API_KEY"), + model="sarvam-30b", + ) + + messages: list[Any] = [ + { + "role": "system", + "content": ( + "You are a helpful LLM in a WebRTC call. Your goal is to " + "demonstrate your capabilities in a succinct way. Your output " + "will be spoken aloud, so avoid special characters that can't " + "easily be spoken, such as emojis or bullet points. Respond to " + "what the user said in a creative and helpful way." + ), + }, + ] + + context = LLMContext(messages) + user_aggregator, assistant_aggregator = LLMContextAggregatorPair( + context, + user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()), + ) + + pipeline = Pipeline( + [ + transport.input(), + stt, + user_aggregator, + llm, + tts, + transport.output(), + assistant_aggregator, + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info("Client connected") + messages.append({"role": "system", "content": "Please introduce yourself to the user."}) + await task.queue_frames([LLMRunFrame()]) + + await asyncio.sleep(10) + logger.info("Updating Sarvam LLM settings: temperature=0.1") + await task.queue_frame(LLMUpdateSettingsFrame(delta=OpenAILLMSettings(temperature=0.1))) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info("Client disconnected") + await task.cancel() + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/src/pipecat/services/sarvam/__init__.py b/src/pipecat/services/sarvam/__init__.py index e8af1401e..1357f737a 100644 --- a/src/pipecat/services/sarvam/__init__.py +++ b/src/pipecat/services/sarvam/__init__.py @@ -4,5 +4,5 @@ # SPDX-License-Identifier: BSD 2-Clause License # - +from .llm import SarvamLLMService from .tts import * diff --git a/src/pipecat/services/sarvam/llm.py b/src/pipecat/services/sarvam/llm.py new file mode 100644 index 000000000..6b440d7f9 --- /dev/null +++ b/src/pipecat/services/sarvam/llm.py @@ -0,0 +1,235 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Sarvam LLM service implementation using OpenAI-compatible interface.""" + +import asyncio +import json +from typing import Any, Literal, Mapping, Optional + +import httpx +from loguru import logger +from openai import NOT_GIVEN, APITimeoutError +from pydantic import BaseModel + +from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.openai.llm import OpenAILLMService +from pipecat.services.sarvam._sdk import sdk_headers + +__all__ = ["SarvamLLMService"] + + +class SarvamLLMService(OpenAILLMService): + """Sarvam LLM service using Sarvam's OpenAI-compatible chat completions API. + + This service extends ``OpenAILLMService`` while adding Sarvam-specific behavior: + + - model allow-list validation + - request shaping for Sarvam-compatible parameters + - Sarvam auth header wiring (``api-subscription-key``) + - SDK User-Agent propagation on every API call + - raw Sarvam server error passthrough + """ + + SUPPORTED_MODELS = frozenset({"sarvam-30b", "sarvam-30b-16k", "sarvam-105b", "sarvam-105b-32k"}) + TOOL_CALLING_MODELS = frozenset( + {"sarvam-30b", "sarvam-30b-16k", "sarvam-105b", "sarvam-105b-32k"} + ) + + class InputParams(OpenAILLMService.InputParams): + """Configuration parameters for Sarvam LLM service. + + Parameters: + frequency_penalty: Penalty for frequent tokens (-2.0 to 2.0). + presence_penalty: Penalty for new tokens (-2.0 to 2.0). + seed: Random seed for deterministic outputs. + temperature: Sampling temperature (0.0 to 2.0). + top_k: Top-k sampling parameter (currently ignored by OpenAI client). + top_p: Top-p (nucleus) sampling parameter (0.0 to 1.0). + max_tokens: Maximum tokens in response. + max_completion_tokens: Maximum completion tokens (not sent to Sarvam API). + service_tier: Service tier (not sent to Sarvam API). + extra: Additional model-specific parameters. + wiki_grounding: Sarvam wiki grounding toggle. + reasoning_effort: Reasoning effort level (low, medium, high). + """ + + wiki_grounding: Optional[bool] = None + reasoning_effort: Optional[Literal["low", "medium", "high"]] = None + + def __init__( + self, + *, + api_key: str, + model: str = "sarvam-30b", + base_url: str = "https://api.sarvam.ai/v1", + default_headers: Optional[Mapping[str, str]] = None, + params: Optional[InputParams] = None, + **kwargs, + ): + """Initialize Sarvam LLM service. + + Args: + api_key: Sarvam API key used for both OpenAI auth and Sarvam subscription header. + model: Sarvam model identifier. Supported values: ``sarvam-30b``, ``sarvam-105b``. + base_url: Sarvam OpenAI-compatible base URL. + default_headers: Additional HTTP headers to include in requests. + params: Input parameters for model configuration. + **kwargs: Additional keyword arguments passed to ``OpenAILLMService``. + """ + self._validate_model(model) + + params = (params or SarvamLLMService.InputParams()).model_copy(deep=True) + params.extra = self._build_extra_params(params) + + super().__init__( + api_key=api_key, + base_url=base_url, + model=model, + default_headers=default_headers, + params=params, + **kwargs, + ) + + def create_client( + self, + api_key=None, + base_url=None, + organization=None, + project=None, + default_headers=None, + **kwargs, + ): + """Create OpenAI-compatible client for Sarvam API endpoint. + + Ensures Sarvam auth and SDK identification headers are always attached. + """ + merged_headers = dict(default_headers or {}) + merged_headers.update(sdk_headers()) + if api_key: + merged_headers["api-subscription-key"] = api_key + + # Keep SDK User-Agent stable even when caller-provided headers include User-Agent. + merged_headers["User-Agent"] = sdk_headers()["User-Agent"] + + logger.debug(f"Creating Sarvam client with API {base_url}") + return super().create_client( + api_key=api_key, + base_url=base_url, + organization=organization, + project=project, + default_headers=merged_headers, + **kwargs, + ) + + def build_chat_completion_params(self, params_from_context: OpenAILLMInvocationParams) -> dict: + """Build parameters for Sarvam chat completion request. + + Starts from OpenAI-compatible defaults, then removes unsupported + request fields and applies Sarvam-specific options. + """ + self._validate_tool_parameters(params_from_context) + + params = super().build_chat_completion_params(params_from_context) + params.pop("stream_options", None) + params.pop("max_completion_tokens", None) + params.pop("service_tier", None) + + extra = self._settings.extra if isinstance(self._settings.extra, dict) else {} + if "wiki_grounding" in extra and extra["wiki_grounding"] is not None: + params["wiki_grounding"] = extra["wiki_grounding"] + if "reasoning_effort" in extra and extra["reasoning_effort"] is not None: + params["reasoning_effort"] = extra["reasoning_effort"] + + return params + + async def get_chat_completions(self, params_from_context: OpenAILLMInvocationParams): + """Get streaming chat completions with Sarvam raw error passthrough.""" + try: + return await super().get_chat_completions(params_from_context) + except (APITimeoutError, asyncio.TimeoutError, httpx.TimeoutException): + raise + except Exception as e: + raise RuntimeError(self._format_raw_server_error(e)) from e + + async def run_inference( + self, context: LLMContext | OpenAILLMContext, max_tokens: Optional[int] = None + ) -> Optional[str]: + """Run one-shot inference and preserve Sarvam raw server errors.""" + try: + return await super().run_inference(context, max_tokens=max_tokens) + except (APITimeoutError, asyncio.TimeoutError, httpx.TimeoutException): + raise + except Exception as e: + raise RuntimeError(self._format_raw_server_error(e)) from e + + def _validate_model(self, model: str): + if model not in self.SUPPORTED_MODELS: + allowed = ", ".join(sorted(self.SUPPORTED_MODELS)) + raise ValueError(f"Unsupported Sarvam LLM model '{model}'. Allowed values: {allowed}.") + + def _build_extra_params(self, params: BaseModel) -> dict[str, Any]: + extra = dict(getattr(params, "extra", {}) or {}) + if getattr(params, "wiki_grounding", None) is not None: + extra["wiki_grounding"] = params.wiki_grounding + if getattr(params, "reasoning_effort", None) is not None: + extra["reasoning_effort"] = params.reasoning_effort + return extra + + def _validate_tool_parameters(self, params_from_context: OpenAILLMInvocationParams): + tools = params_from_context.get("tools", NOT_GIVEN) + tool_choice = params_from_context.get("tool_choice", NOT_GIVEN) + + has_tools = ( + tools is not NOT_GIVEN + and tools is not None + and (not isinstance(tools, list) or len(tools) > 0) + ) + has_tool_choice = tool_choice is not NOT_GIVEN and tool_choice is not None + + if has_tool_choice and not has_tools: + raise ValueError("Sarvam requires non-empty `tools` when `tool_choice` is provided.") + + if has_tools and self._settings.model not in self.TOOL_CALLING_MODELS: + allowed = ", ".join(sorted(self.TOOL_CALLING_MODELS)) + raise ValueError( + f"Model '{self._settings.model}' does not support tools. " + f"Supported models: {allowed}." + ) + + def _format_raw_server_error(self, error: Exception) -> str: + raw_message = self._extract_raw_server_message(error) + return f"Sarvam server error: {raw_message}" + + def _extract_raw_server_message(self, error: Exception) -> str: + body = getattr(error, "body", None) + if body is not None: + return self._payload_to_message(body) + + response = getattr(error, "response", None) + if response is not None: + try: + return self._payload_to_message(response.json()) + except Exception: + text = getattr(response, "text", None) + if text: + return str(text) + + return str(error) + + def _payload_to_message(self, payload: Any) -> str: + if isinstance(payload, dict): + error_obj = payload.get("error") + if isinstance(error_obj, dict) and isinstance(error_obj.get("message"), str): + return error_obj["message"] + if isinstance(payload.get("message"), str): + return payload["message"] + return json.dumps(payload, ensure_ascii=False) + if isinstance(payload, list): + return json.dumps(payload, ensure_ascii=False) + return str(payload)