services: BaseOpenAILLMService.create_client() now returns the client

This commit is contained in:
Aleix Conchillo Flaqué
2024-05-24 09:04:15 -07:00
parent 32f91c5f31
commit 4e594aa9b0
3 changed files with 4 additions and 8 deletions

View File

@@ -84,7 +84,7 @@ class AzureLLMService(BaseOpenAILLMService):
super().__init__(api_key=api_key, model=model)
def create_client(self, api_key=None, base_url=None):
self._client = AsyncAzureOpenAI(
return AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=self._endpoint,
api_version=self._api_version,

View File

@@ -42,12 +42,8 @@ class GoogleLLMService(LLMService):
def __init__(self, model="gemini-1.5-flash-latest", api_key=None, **kwargs):
super().__init__(**kwargs)
self.model = model
gai.configure(api_key=api_key or os.environ["GOOGLE_API_KEY"])
self.create_client()
def create_client(self):
self._client = gai.GenerativeModel(self.model)
self._client = gai.GenerativeModel(model)
def _get_messages_from_openai_context(
self, context: OpenAILLMContext) -> List[glm.Content]:

View File

@@ -58,10 +58,10 @@ class BaseOpenAILLMService(LLMService):
def __init__(self, model: str, api_key=None, base_url=None):
super().__init__()
self._model: str = model
self.create_client(api_key=api_key, base_url=base_url)
self._client = self.create_client(api_key=api_key, base_url=base_url)
def create_client(self, api_key=None, base_url=None):
self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
return AsyncOpenAI(api_key=api_key, base_url=base_url)
async def _stream_chat_completions(
self, context: OpenAILLMContext