diff --git a/src/pipecat/services/sarvam/llm.py b/src/pipecat/services/sarvam/llm.py index 6b440d7f9..c99e3dfb9 100644 --- a/src/pipecat/services/sarvam/llm.py +++ b/src/pipecat/services/sarvam/llm.py @@ -8,20 +8,40 @@ import asyncio import json -from typing import Any, Literal, Mapping, Optional +from dataclasses import dataclass, field +from typing import Any, Awaitable, Literal, Mapping, Optional, TypeVar import httpx from loguru import logger -from openai import NOT_GIVEN, APITimeoutError -from pydantic import BaseModel +from openai import NOT_GIVEN, APITimeoutError, AsyncStream +from openai.types.chat import ChatCompletionChunk from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.openai.base_llm import OpenAILLMSettings from pipecat.services.openai.llm import OpenAILLMService from pipecat.services.sarvam._sdk import sdk_headers +from pipecat.services.settings import NOT_GIVEN as _NOT_GIVEN +from pipecat.services.settings import _NotGiven, _warn_deprecated_param, is_given -__all__ = ["SarvamLLMService"] +__all__ = ["SarvamLLMService", "SarvamLLMSettings"] +_T = TypeVar("_T") + + +@dataclass +class SarvamLLMSettings(OpenAILLMSettings): + """Settings for SarvamLLMService. + + Parameters: + wiki_grounding: Sarvam wiki grounding toggle. + reasoning_effort: Reasoning effort level (low, medium, high). + """ + + wiki_grounding: bool | None | _NotGiven = field(default_factory=lambda: _NOT_GIVEN) + reasoning_effort: Literal["low", "medium", "high"] | None | _NotGiven = field( + default_factory=lambda: _NOT_GIVEN + ) class SarvamLLMService(OpenAILLMService): @@ -40,59 +60,59 @@ class SarvamLLMService(OpenAILLMService): TOOL_CALLING_MODELS = frozenset( {"sarvam-30b", "sarvam-30b-16k", "sarvam-105b", "sarvam-105b-32k"} ) - - class InputParams(OpenAILLMService.InputParams): - """Configuration parameters for Sarvam LLM service. - - Parameters: - frequency_penalty: Penalty for frequent tokens (-2.0 to 2.0). - presence_penalty: Penalty for new tokens (-2.0 to 2.0). - seed: Random seed for deterministic outputs. - temperature: Sampling temperature (0.0 to 2.0). - top_k: Top-k sampling parameter (currently ignored by OpenAI client). - top_p: Top-p (nucleus) sampling parameter (0.0 to 1.0). - max_tokens: Maximum tokens in response. - max_completion_tokens: Maximum completion tokens (not sent to Sarvam API). - service_tier: Service tier (not sent to Sarvam API). - extra: Additional model-specific parameters. - wiki_grounding: Sarvam wiki grounding toggle. - reasoning_effort: Reasoning effort level (low, medium, high). - """ - - wiki_grounding: Optional[bool] = None - reasoning_effort: Optional[Literal["low", "medium", "high"]] = None + Settings = SarvamLLMSettings + _settings: SarvamLLMSettings def __init__( self, *, api_key: str, - model: str = "sarvam-30b", base_url: str = "https://api.sarvam.ai/v1", + model: Optional[str] = None, + settings: Optional[SarvamLLMSettings] = None, default_headers: Optional[Mapping[str, str]] = None, - params: Optional[InputParams] = None, **kwargs, ): """Initialize Sarvam LLM service. Args: api_key: Sarvam API key used for both OpenAI auth and Sarvam subscription header. - model: Sarvam model identifier. Supported values: ``sarvam-30b``, ``sarvam-105b``. base_url: Sarvam OpenAI-compatible base URL. + model: Sarvam model identifier. Supported values: ``sarvam-30b``, + ``sarvam-30b-16k``, ``sarvam-105b``, ``sarvam-105b-32k``. + + .. deprecated:: 0.0.105 + Use ``settings=SarvamLLMSettings(model=...)`` instead. + + settings: Runtime-updatable settings. When provided alongside deprecated + parameters, ``settings`` values take precedence. default_headers: Additional HTTP headers to include in requests. - params: Input parameters for model configuration. **kwargs: Additional keyword arguments passed to ``OpenAILLMService``. """ - self._validate_model(model) + # 1. Initialize default_settings with hardcoded defaults + default_settings = SarvamLLMSettings(model="sarvam-30b") - params = (params or SarvamLLMService.InputParams()).model_copy(deep=True) - params.extra = self._build_extra_params(params) + # 2. Apply direct init arg overrides (deprecated) + if model is not None: + _warn_deprecated_param("model", SarvamLLMSettings, "model") + default_settings.model = model + + # 3. Apply settings delta (canonical API, always wins) + if settings is not None: + default_settings.apply_update(settings) + + # BaseOpenAILLMService stores settings as OpenAILLMSettings, so keep + # Sarvam-specific runtime knobs in ``extra``. + default_settings.extra = dict(default_settings.extra) + default_settings.extra.update(self._extract_sarvam_extra_from_settings(default_settings)) + + self._validate_model(default_settings.model) super().__init__( api_key=api_key, base_url=base_url, - model=model, + settings=default_settings, default_headers=default_headers, - params=params, **kwargs, ) @@ -110,13 +130,11 @@ class SarvamLLMService(OpenAILLMService): Ensures Sarvam auth and SDK identification headers are always attached. """ merged_headers = dict(default_headers or {}) + # sdk_headers() carries Pipecat User-Agent and should override caller-provided value. merged_headers.update(sdk_headers()) if api_key: merged_headers["api-subscription-key"] = api_key - # Keep SDK User-Agent stable even when caller-provided headers include User-Agent. - merged_headers["User-Agent"] = sdk_headers()["User-Agent"] - logger.debug(f"Creating Sarvam client with API {base_url}") return super().create_client( api_key=api_key, @@ -148,38 +166,63 @@ class SarvamLLMService(OpenAILLMService): return params - async def get_chat_completions(self, params_from_context: OpenAILLMInvocationParams): - """Get streaming chat completions with Sarvam raw error passthrough.""" + async def _update_settings(self, delta: OpenAILLMSettings) -> dict[str, Any]: + """Apply settings updates, preserving Sarvam-specific runtime knobs.""" + sarvam_extra = self._extract_sarvam_extra_from_settings(delta) + if sarvam_extra: + delta.extra = dict(delta.extra) + delta.extra.update(sarvam_extra) + + return await super()._update_settings(delta) + + async def _call_with_raw_sarvam_errors(self, awaitable: Awaitable[_T]) -> _T: + """Await an OpenAI call while preserving Sarvam raw error payloads.""" try: - return await super().get_chat_completions(params_from_context) + return await awaitable except (APITimeoutError, asyncio.TimeoutError, httpx.TimeoutException): raise except Exception as e: raise RuntimeError(self._format_raw_server_error(e)) from e + async def get_chat_completions( + self, params_from_context: OpenAILLMInvocationParams + ) -> AsyncStream[ChatCompletionChunk]: + """Get streaming chat completions with Sarvam raw error passthrough.""" + return await self._call_with_raw_sarvam_errors( + super().get_chat_completions(params_from_context) + ) + async def run_inference( - self, context: LLMContext | OpenAILLMContext, max_tokens: Optional[int] = None + self, + context: LLMContext | OpenAILLMContext, + max_tokens: Optional[int] = None, + system_instruction: Optional[str] = None, ) -> Optional[str]: """Run one-shot inference and preserve Sarvam raw server errors.""" - try: - return await super().run_inference(context, max_tokens=max_tokens) - except (APITimeoutError, asyncio.TimeoutError, httpx.TimeoutException): - raise - except Exception as e: - raise RuntimeError(self._format_raw_server_error(e)) from e + return await self._call_with_raw_sarvam_errors( + super().run_inference( + context, + max_tokens=max_tokens, + system_instruction=system_instruction, + ) + ) def _validate_model(self, model: str): if model not in self.SUPPORTED_MODELS: allowed = ", ".join(sorted(self.SUPPORTED_MODELS)) raise ValueError(f"Unsupported Sarvam LLM model '{model}'. Allowed values: {allowed}.") - def _build_extra_params(self, params: BaseModel) -> dict[str, Any]: - extra = dict(getattr(params, "extra", {}) or {}) - if getattr(params, "wiki_grounding", None) is not None: - extra["wiki_grounding"] = params.wiki_grounding - if getattr(params, "reasoning_effort", None) is not None: - extra["reasoning_effort"] = params.reasoning_effort - return extra + def _extract_sarvam_extra_from_settings(self, settings_obj: Any) -> dict[str, Any]: + updates: dict[str, Any] = {} + wiki_grounding = getattr(settings_obj, "wiki_grounding", _NOT_GIVEN) + if is_given(wiki_grounding): + updates["wiki_grounding"] = wiki_grounding + + reasoning_effort = getattr(settings_obj, "reasoning_effort", _NOT_GIVEN) + if is_given(reasoning_effort): + updates["reasoning_effort"] = reasoning_effort + + return updates def _validate_tool_parameters(self, params_from_context: OpenAILLMInvocationParams): tools = params_from_context.get("tools", NOT_GIVEN)