services: BaseOpenAILLMService.create_client() now returns the client
This commit is contained in:
@@ -84,7 +84,7 @@ class AzureLLMService(BaseOpenAILLMService):
|
||||
super().__init__(api_key=api_key, model=model)
|
||||
|
||||
def create_client(self, api_key=None, base_url=None):
|
||||
self._client = AsyncAzureOpenAI(
|
||||
return AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=self._endpoint,
|
||||
api_version=self._api_version,
|
||||
|
||||
@@ -42,12 +42,8 @@ class GoogleLLMService(LLMService):
|
||||
|
||||
def __init__(self, model="gemini-1.5-flash-latest", api_key=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.model = model
|
||||
gai.configure(api_key=api_key or os.environ["GOOGLE_API_KEY"])
|
||||
self.create_client()
|
||||
|
||||
def create_client(self):
|
||||
self._client = gai.GenerativeModel(self.model)
|
||||
self._client = gai.GenerativeModel(model)
|
||||
|
||||
def _get_messages_from_openai_context(
|
||||
self, context: OpenAILLMContext) -> List[glm.Content]:
|
||||
|
||||
@@ -58,10 +58,10 @@ class BaseOpenAILLMService(LLMService):
|
||||
def __init__(self, model: str, api_key=None, base_url=None):
|
||||
super().__init__()
|
||||
self._model: str = model
|
||||
self.create_client(api_key=api_key, base_url=base_url)
|
||||
self._client = self.create_client(api_key=api_key, base_url=base_url)
|
||||
|
||||
def create_client(self, api_key=None, base_url=None):
|
||||
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
return AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
async def _stream_chat_completions(
|
||||
self, context: OpenAILLMContext
|
||||
|
||||
Reference in New Issue
Block a user