Input params for Together AI LLM
This commit is contained in:
@@ -13,6 +13,7 @@ from dataclasses import dataclass
|
||||
from asyncio import CancelledError
|
||||
import re
|
||||
import uuid
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
@@ -58,18 +59,30 @@ class TogetherContextAggregatorPair:
|
||||
class TogetherLLMService(LLMService):
|
||||
"""This class implements inference with Together's Llama 3.1 models
|
||||
"""
|
||||
class InputParams(BaseModel):
|
||||
frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
|
||||
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
||||
presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
|
||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
max_tokens: int = 4096,
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._client = AsyncTogether(api_key=api_key)
|
||||
self.set_model_name(model)
|
||||
self._max_tokens = max_tokens
|
||||
self._max_tokens = params.max_tokens
|
||||
self._frequency_penalty = params.frequency_penalty
|
||||
self._presence_penalty = params.presence_penalty
|
||||
self._temperature = params.temperature
|
||||
self._top_k = params.top_k
|
||||
self._top_p = params.top_p
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -83,6 +96,30 @@ class TogetherLLMService(LLMService):
|
||||
_assistant=assistant
|
||||
)
|
||||
|
||||
async def set_frequency_penalty(self, frequency_penalty: float):
|
||||
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
|
||||
self._frequency_penalty = frequency_penalty
|
||||
|
||||
async def set_max_tokens(self, max_tokens: int):
|
||||
logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
|
||||
self._max_tokens = max_tokens
|
||||
|
||||
async def set_presence_penalty(self, presence_penalty: float):
|
||||
logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
|
||||
self._presence_penalty = presence_penalty
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_k(self, top_k: float):
|
||||
logger.debug(f"Switching LLM top_k to: [{top_k}]")
|
||||
self._top_k = top_k
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
@@ -97,6 +134,11 @@ class TogetherLLMService(LLMService):
|
||||
model=self.model_name,
|
||||
max_tokens=self._max_tokens,
|
||||
stream=True,
|
||||
frequency_penalty=self._frequency_penalty,
|
||||
presence_penalty=self._presence_penalty,
|
||||
temperature=self._temperature,
|
||||
top_k=self._top_k,
|
||||
top_p=self._top_p
|
||||
)
|
||||
|
||||
# Function calling
|
||||
|
||||
Reference in New Issue
Block a user