fix llm wrapper redundancy and restore run_inference parity
This commit is contained in:
@@ -8,20 +8,40 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Literal, Mapping, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Literal, Mapping, Optional, TypeVar
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from openai import NOT_GIVEN, APITimeoutError
|
||||
from pydantic import BaseModel
|
||||
from openai import NOT_GIVEN, APITimeoutError, AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.sarvam._sdk import sdk_headers
|
||||
from pipecat.services.settings import NOT_GIVEN as _NOT_GIVEN
|
||||
from pipecat.services.settings import _NotGiven, _warn_deprecated_param, is_given
|
||||
|
||||
__all__ = ["SarvamLLMService"]
|
||||
__all__ = ["SarvamLLMService", "SarvamLLMSettings"]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@dataclass
class SarvamLLMSettings(OpenAILLMSettings):
    """Runtime settings accepted by ``SarvamLLMService``.

    Extends ``OpenAILLMSettings`` with two Sarvam-specific fields. Each field
    defaults to the ``_NOT_GIVEN`` sentinel, produced through a default factory.

    Parameters:
        wiki_grounding: Sarvam wiki grounding toggle.
        reasoning_effort: Reasoning effort level (low, medium, high).
    """

    # Toggle for Sarvam wiki grounding; the sentinel default marks "not supplied".
    wiki_grounding: bool | None | _NotGiven = field(
        default_factory=lambda: _NOT_GIVEN,
    )

    # One of "low" / "medium" / "high"; the sentinel default marks "not supplied".
    reasoning_effort: Literal["low", "medium", "high"] | None | _NotGiven = field(
        default_factory=lambda: _NOT_GIVEN,
    )
|
||||
|
||||
|
||||
class SarvamLLMService(OpenAILLMService):
|
||||
@@ -40,59 +60,59 @@ class SarvamLLMService(OpenAILLMService):
|
||||
# Immutable set of model identifiers grouped under "tool calling".
# NOTE(review): presumably the models for which tool parameters are accepted;
# confirm against the tool-parameter validation code.
TOOL_CALLING_MODELS = frozenset(
    (
        "sarvam-30b",
        "sarvam-30b-16k",
        "sarvam-105b",
        "sarvam-105b-32k",
    )
)
|
||||
|
||||
class InputParams(OpenAILLMService.InputParams):
    """Input parameters for the Sarvam LLM service.

    Builds on ``OpenAILLMService.InputParams`` and declares two additional
    optional fields, both defaulting to ``None``.

    Parameters:
        frequency_penalty: Penalty for frequent tokens (-2.0 to 2.0).
        presence_penalty: Penalty for new tokens (-2.0 to 2.0).
        seed: Random seed for deterministic outputs.
        temperature: Sampling temperature (0.0 to 2.0).
        top_k: Top-k sampling parameter (currently ignored by OpenAI client).
        top_p: Top-p (nucleus) sampling parameter (0.0 to 1.0).
        max_tokens: Maximum tokens in response.
        max_completion_tokens: Maximum completion tokens (not sent to Sarvam API).
        service_tier: Service tier (not sent to Sarvam API).
        extra: Additional model-specific parameters.
        wiki_grounding: Sarvam wiki grounding toggle.
        reasoning_effort: Reasoning effort level (low, medium, high).
    """

    # Sarvam wiki grounding toggle; None means it was not set.
    wiki_grounding: Optional[bool] = None

    # Reasoning effort level; None means it was not set.
    reasoning_effort: Optional[
        Literal["low", "medium", "high"]
    ] = None
|
||||
Settings = SarvamLLMSettings
|
||||
_settings: SarvamLLMSettings
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "sarvam-30b",
|
||||
base_url: str = "https://api.sarvam.ai/v1",
|
||||
model: Optional[str] = None,
|
||||
settings: Optional[SarvamLLMSettings] = None,
|
||||
default_headers: Optional[Mapping[str, str]] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Sarvam LLM service.
|
||||
|
||||
Args:
|
||||
api_key: Sarvam API key used for both OpenAI auth and Sarvam subscription header.
|
||||
model: Sarvam model identifier. Supported values: ``sarvam-30b``, ``sarvam-105b``.
|
||||
base_url: Sarvam OpenAI-compatible base URL.
|
||||
model: Sarvam model identifier. Supported values: ``sarvam-30b``,
|
||||
``sarvam-30b-16k``, ``sarvam-105b``, ``sarvam-105b-32k``.
|
||||
|
||||
.. deprecated:: 0.0.105
|
||||
Use ``settings=SarvamLLMSettings(model=...)`` instead.
|
||||
|
||||
settings: Runtime-updatable settings. When provided alongside deprecated
|
||||
parameters, ``settings`` values take precedence.
|
||||
default_headers: Additional HTTP headers to include in requests.
|
||||
params: Input parameters for model configuration.
|
||||
**kwargs: Additional keyword arguments passed to ``OpenAILLMService``.
|
||||
"""
|
||||
self._validate_model(model)
|
||||
# 1. Initialize default_settings with hardcoded defaults
|
||||
default_settings = SarvamLLMSettings(model="sarvam-30b")
|
||||
|
||||
params = (params or SarvamLLMService.InputParams()).model_copy(deep=True)
|
||||
params.extra = self._build_extra_params(params)
|
||||
# 2. Apply direct init arg overrides (deprecated)
|
||||
if model is not None:
|
||||
_warn_deprecated_param("model", SarvamLLMSettings, "model")
|
||||
default_settings.model = model
|
||||
|
||||
# 3. Apply settings delta (canonical API, always wins)
|
||||
if settings is not None:
|
||||
default_settings.apply_update(settings)
|
||||
|
||||
# BaseOpenAILLMService stores settings as OpenAILLMSettings, so keep
|
||||
# Sarvam-specific runtime knobs in ``extra``.
|
||||
default_settings.extra = dict(default_settings.extra)
|
||||
default_settings.extra.update(self._extract_sarvam_extra_from_settings(default_settings))
|
||||
|
||||
self._validate_model(default_settings.model)
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
settings=default_settings,
|
||||
default_headers=default_headers,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -110,13 +130,11 @@ class SarvamLLMService(OpenAILLMService):
|
||||
Ensures Sarvam auth and SDK identification headers are always attached.
|
||||
"""
|
||||
merged_headers = dict(default_headers or {})
|
||||
# sdk_headers() carries Pipecat User-Agent and should override caller-provided value.
|
||||
merged_headers.update(sdk_headers())
|
||||
if api_key:
|
||||
merged_headers["api-subscription-key"] = api_key
|
||||
|
||||
# Keep SDK User-Agent stable even when caller-provided headers include User-Agent.
|
||||
merged_headers["User-Agent"] = sdk_headers()["User-Agent"]
|
||||
|
||||
logger.debug(f"Creating Sarvam client with API {base_url}")
|
||||
return super().create_client(
|
||||
api_key=api_key,
|
||||
@@ -148,38 +166,63 @@ class SarvamLLMService(OpenAILLMService):
|
||||
|
||||
return params
|
||||
|
||||
async def get_chat_completions(self, params_from_context: OpenAILLMInvocationParams):
|
||||
"""Get streaming chat completions with Sarvam raw error passthrough."""
|
||||
async def _update_settings(self, delta: OpenAILLMSettings) -> dict[str, Any]:
    """Merge Sarvam-specific fields into ``delta.extra`` and delegate to the parent.

    Args:
        delta: The settings update to apply. When it carries Sarvam-specific
            values, its ``extra`` attribute is replaced by a copy that also
            contains those values.

    Returns:
        Whatever the parent class's ``_update_settings`` returns for ``delta``.
    """
    overrides = self._extract_sarvam_extra_from_settings(delta)
    if overrides:
        # Work on a copy so the mapping originally attached to ``delta`` is not mutated.
        merged = dict(delta.extra)
        merged.update(overrides)
        delta.extra = merged
    return await super()._update_settings(delta)
|
||||
|
||||
async def _call_with_raw_sarvam_errors(self, awaitable: Awaitable[_T]) -> _T:
|
||||
"""Await an OpenAI call while preserving Sarvam raw error payloads."""
|
||||
try:
|
||||
return await super().get_chat_completions(params_from_context)
|
||||
return await awaitable
|
||||
except (APITimeoutError, asyncio.TimeoutError, httpx.TimeoutException):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(self._format_raw_server_error(e)) from e
|
||||
|
||||
async def get_chat_completions(
    self, params_from_context: OpenAILLMInvocationParams
) -> AsyncStream[ChatCompletionChunk]:
    """Request streaming chat completions, passing Sarvam raw errors through.

    Args:
        params_from_context: Invocation parameters forwarded unchanged to the
            parent class's ``get_chat_completions``.

    Returns:
        The result of the parent call, awaited through
        ``_call_with_raw_sarvam_errors``.
    """
    # Build the parent coroutine first, then let the shared helper await it.
    pending = super().get_chat_completions(params_from_context)
    return await self._call_with_raw_sarvam_errors(pending)
|
||||
|
||||
async def run_inference(
    self,
    context: LLMContext | OpenAILLMContext,
    max_tokens: Optional[int] = None,
    system_instruction: Optional[str] = None,
) -> Optional[str]:
    """Run one-shot inference and preserve Sarvam raw server errors.

    Args:
        context: The LLM context forwarded to the parent class's ``run_inference``.
        max_tokens: Optional token limit, forwarded to the parent call.
        system_instruction: Optional system instruction, forwarded to the
            parent call.

    Returns:
        The result of the parent call, awaited through
        ``_call_with_raw_sarvam_errors``.
    """
    # Same error-handling path as get_chat_completions: the shared helper awaits
    # the parent coroutine, so no separate try/except is kept here.
    return await self._call_with_raw_sarvam_errors(
        super().run_inference(
            context,
            max_tokens=max_tokens,
            system_instruction=system_instruction,
        )
    )
|
||||
|
||||
def _validate_model(self, model: str):
|
||||
if model not in self.SUPPORTED_MODELS:
|
||||
allowed = ", ".join(sorted(self.SUPPORTED_MODELS))
|
||||
raise ValueError(f"Unsupported Sarvam LLM model '{model}'. Allowed values: {allowed}.")
|
||||
|
||||
def _build_extra_params(self, params: BaseModel) -> dict[str, Any]:
|
||||
extra = dict(getattr(params, "extra", {}) or {})
|
||||
if getattr(params, "wiki_grounding", None) is not None:
|
||||
extra["wiki_grounding"] = params.wiki_grounding
|
||||
if getattr(params, "reasoning_effort", None) is not None:
|
||||
extra["reasoning_effort"] = params.reasoning_effort
|
||||
return extra
|
||||
def _extract_sarvam_extra_from_settings(self, settings_obj: Any) -> dict[str, Any]:
    """Collect the Sarvam-specific fields that are set on ``settings_obj``.

    Args:
        settings_obj: Any object; its ``wiki_grounding`` and
            ``reasoning_effort`` attributes are read, falling back to the
            ``_NOT_GIVEN`` sentinel when an attribute is missing.

    Returns:
        A dict holding an entry for each of the two fields whose value passes
        ``is_given``; empty when neither does.
    """
    collected: dict[str, Any] = {}
    for name in ("wiki_grounding", "reasoning_effort"):
        candidate = getattr(settings_obj, name, _NOT_GIVEN)
        if is_given(candidate):
            collected[name] = candidate
    return collected
|
||||
|
||||
def _validate_tool_parameters(self, params_from_context: OpenAILLMInvocationParams):
|
||||
tools = params_from_context.get("tools", NOT_GIVEN)
|
||||
|
||||
Reference in New Issue
Block a user