fix llm wrapper redundancy and restore run_inference parity

This commit is contained in:
dhruvladia-sarvam
2026-03-15 22:24:06 +05:30
parent dc0386937a
commit 8745f20330

View File

@@ -8,20 +8,40 @@
import asyncio
import json
from typing import Any, Literal, Mapping, Optional
from dataclasses import dataclass, field
from typing import Any, Awaitable, Literal, Mapping, Optional, TypeVar
import httpx
from loguru import logger
from openai import NOT_GIVEN, APITimeoutError
from pydantic import BaseModel
from openai import NOT_GIVEN, APITimeoutError, AsyncStream
from openai.types.chat import ChatCompletionChunk
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai.base_llm import OpenAILLMSettings
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.sarvam._sdk import sdk_headers
from pipecat.services.settings import NOT_GIVEN as _NOT_GIVEN
from pipecat.services.settings import _NotGiven, _warn_deprecated_param, is_given
__all__ = ["SarvamLLMService"]
__all__ = ["SarvamLLMService", "SarvamLLMSettings"]
_T = TypeVar("_T")
@dataclass
class SarvamLLMSettings(OpenAILLMSettings):
"""Settings for SarvamLLMService.
Parameters:
wiki_grounding: Sarvam wiki grounding toggle.
reasoning_effort: Reasoning effort level (low, medium, high).
"""
wiki_grounding: bool | None | _NotGiven = field(default_factory=lambda: _NOT_GIVEN)
reasoning_effort: Literal["low", "medium", "high"] | None | _NotGiven = field(
default_factory=lambda: _NOT_GIVEN
)
class SarvamLLMService(OpenAILLMService):
@@ -40,59 +60,59 @@ class SarvamLLMService(OpenAILLMService):
TOOL_CALLING_MODELS = frozenset(
{"sarvam-30b", "sarvam-30b-16k", "sarvam-105b", "sarvam-105b-32k"}
)
class InputParams(OpenAILLMService.InputParams):
"""Configuration parameters for Sarvam LLM service.
Parameters:
frequency_penalty: Penalty for frequent tokens (-2.0 to 2.0).
presence_penalty: Penalty for new tokens (-2.0 to 2.0).
seed: Random seed for deterministic outputs.
temperature: Sampling temperature (0.0 to 2.0).
top_k: Top-k sampling parameter (currently ignored by OpenAI client).
top_p: Top-p (nucleus) sampling parameter (0.0 to 1.0).
max_tokens: Maximum tokens in response.
max_completion_tokens: Maximum completion tokens (not sent to Sarvam API).
service_tier: Service tier (not sent to Sarvam API).
extra: Additional model-specific parameters.
wiki_grounding: Sarvam wiki grounding toggle.
reasoning_effort: Reasoning effort level (low, medium, high).
"""
wiki_grounding: Optional[bool] = None
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
Settings = SarvamLLMSettings
_settings: SarvamLLMSettings
def __init__(
self,
*,
api_key: str,
model: str = "sarvam-30b",
base_url: str = "https://api.sarvam.ai/v1",
model: Optional[str] = None,
settings: Optional[SarvamLLMSettings] = None,
default_headers: Optional[Mapping[str, str]] = None,
params: Optional[InputParams] = None,
**kwargs,
):
"""Initialize Sarvam LLM service.
Args:
api_key: Sarvam API key used for both OpenAI auth and Sarvam subscription header.
model: Sarvam model identifier. Supported values: ``sarvam-30b``, ``sarvam-105b``.
base_url: Sarvam OpenAI-compatible base URL.
model: Sarvam model identifier. Supported values: ``sarvam-30b``,
``sarvam-30b-16k``, ``sarvam-105b``, ``sarvam-105b-32k``.
.. deprecated:: 0.0.105
Use ``settings=SarvamLLMSettings(model=...)`` instead.
settings: Runtime-updatable settings. When provided alongside deprecated
parameters, ``settings`` values take precedence.
default_headers: Additional HTTP headers to include in requests.
params: Input parameters for model configuration.
**kwargs: Additional keyword arguments passed to ``OpenAILLMService``.
"""
self._validate_model(model)
# 1. Initialize default_settings with hardcoded defaults
default_settings = SarvamLLMSettings(model="sarvam-30b")
params = (params or SarvamLLMService.InputParams()).model_copy(deep=True)
params.extra = self._build_extra_params(params)
# 2. Apply direct init arg overrides (deprecated)
if model is not None:
_warn_deprecated_param("model", SarvamLLMSettings, "model")
default_settings.model = model
# 3. Apply settings delta (canonical API, always wins)
if settings is not None:
default_settings.apply_update(settings)
# BaseOpenAILLMService stores settings as OpenAILLMSettings, so keep
# Sarvam-specific runtime knobs in ``extra``.
default_settings.extra = dict(default_settings.extra)
default_settings.extra.update(self._extract_sarvam_extra_from_settings(default_settings))
self._validate_model(default_settings.model)
super().__init__(
api_key=api_key,
base_url=base_url,
model=model,
settings=default_settings,
default_headers=default_headers,
params=params,
**kwargs,
)
@@ -110,13 +130,11 @@ class SarvamLLMService(OpenAILLMService):
Ensures Sarvam auth and SDK identification headers are always attached.
"""
merged_headers = dict(default_headers or {})
# sdk_headers() carries Pipecat User-Agent and should override caller-provided value.
merged_headers.update(sdk_headers())
if api_key:
merged_headers["api-subscription-key"] = api_key
# Keep SDK User-Agent stable even when caller-provided headers include User-Agent.
merged_headers["User-Agent"] = sdk_headers()["User-Agent"]
logger.debug(f"Creating Sarvam client with API {base_url}")
return super().create_client(
api_key=api_key,
@@ -148,38 +166,63 @@ class SarvamLLMService(OpenAILLMService):
return params
async def get_chat_completions(self, params_from_context: OpenAILLMInvocationParams):
"""Get streaming chat completions with Sarvam raw error passthrough."""
async def _update_settings(self, delta: OpenAILLMSettings) -> dict[str, Any]:
"""Apply settings updates, preserving Sarvam-specific runtime knobs."""
sarvam_extra = self._extract_sarvam_extra_from_settings(delta)
if sarvam_extra:
delta.extra = dict(delta.extra)
delta.extra.update(sarvam_extra)
return await super()._update_settings(delta)
async def _call_with_raw_sarvam_errors(self, awaitable: Awaitable[_T]) -> _T:
"""Await an OpenAI call while preserving Sarvam raw error payloads."""
try:
return await super().get_chat_completions(params_from_context)
return await awaitable
except (APITimeoutError, asyncio.TimeoutError, httpx.TimeoutException):
raise
except Exception as e:
raise RuntimeError(self._format_raw_server_error(e)) from e
async def get_chat_completions(
self, params_from_context: OpenAILLMInvocationParams
) -> AsyncStream[ChatCompletionChunk]:
"""Get streaming chat completions with Sarvam raw error passthrough."""
return await self._call_with_raw_sarvam_errors(
super().get_chat_completions(params_from_context)
)
async def run_inference(
self, context: LLMContext | OpenAILLMContext, max_tokens: Optional[int] = None
self,
context: LLMContext | OpenAILLMContext,
max_tokens: Optional[int] = None,
system_instruction: Optional[str] = None,
) -> Optional[str]:
"""Run one-shot inference and preserve Sarvam raw server errors."""
try:
return await super().run_inference(context, max_tokens=max_tokens)
except (APITimeoutError, asyncio.TimeoutError, httpx.TimeoutException):
raise
except Exception as e:
raise RuntimeError(self._format_raw_server_error(e)) from e
return await self._call_with_raw_sarvam_errors(
super().run_inference(
context,
max_tokens=max_tokens,
system_instruction=system_instruction,
)
)
def _validate_model(self, model: str):
if model not in self.SUPPORTED_MODELS:
allowed = ", ".join(sorted(self.SUPPORTED_MODELS))
raise ValueError(f"Unsupported Sarvam LLM model '{model}'. Allowed values: {allowed}.")
def _build_extra_params(self, params: BaseModel) -> dict[str, Any]:
extra = dict(getattr(params, "extra", {}) or {})
if getattr(params, "wiki_grounding", None) is not None:
extra["wiki_grounding"] = params.wiki_grounding
if getattr(params, "reasoning_effort", None) is not None:
extra["reasoning_effort"] = params.reasoning_effort
return extra
def _extract_sarvam_extra_from_settings(self, settings_obj: Any) -> dict[str, Any]:
updates: dict[str, Any] = {}
wiki_grounding = getattr(settings_obj, "wiki_grounding", _NOT_GIVEN)
if is_given(wiki_grounding):
updates["wiki_grounding"] = wiki_grounding
reasoning_effort = getattr(settings_obj, "reasoning_effort", _NOT_GIVEN)
if is_given(reasoning_effort):
updates["reasoning_effort"] = reasoning_effort
return updates
def _validate_tool_parameters(self, params_from_context: OpenAILLMInvocationParams):
tools = params_from_context.get("tools", NOT_GIVEN)