From c8bd76cd66a94deca2d2b19fd53e342e627b3d07 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 16 Sep 2023 03:09:36 +0800 Subject: [PATCH] fix: inference embedding validate (#1187) --- .../providers/xinference_provider.py | 21 +++++++++++++------ api/requirements.txt | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/api/core/model_providers/providers/xinference_provider.py b/api/core/model_providers/providers/xinference_provider.py index 0f3243d806..7c43804c7f 100644 --- a/api/core/model_providers/providers/xinference_provider.py +++ b/api/core/model_providers/providers/xinference_provider.py @@ -2,6 +2,7 @@ import json from typing import Type import requests +from langchain.embeddings import XinferenceEmbeddings from core.helper import encrypter from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding @@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider): 'model_uid': credentials['model_uid'], } - llm = XinferenceLLM( - **credential_kwargs - ) + if model_type == ModelType.TEXT_GENERATION: + llm = XinferenceLLM( + **credential_kwargs + ) - llm("ping") + llm("ping") + elif model_type == ModelType.EMBEDDINGS: + embedding = XinferenceEmbeddings( + **credential_kwargs + ) + + embedding.embed_query("ping") except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider): :param credentials: :return: """ - extra_credentials = cls._get_extra_credentials(credentials) - credentials.update(extra_credentials) + if model_type == ModelType.TEXT_GENERATION: + extra_credentials = cls._get_extra_credentials(credentials) + credentials.update(extra_credentials) credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url']) diff --git a/api/requirements.txt b/api/requirements.txt index 0d904ced46..26bdce61da 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -19,7 +19,7 @@ pytest~=7.3.1 pytest-mock~=3.11.1 tiktoken==0.3.3 Authlib==1.2.0 -boto3~=1.26.123 +boto3==1.28.17 tenacity==8.2.2 cachetools~=5.3.0 weaviate-client~=3.21.0 @@ -49,5 +49,5 @@ huggingface_hub~=0.16.4 transformers~=4.31.0 stripe~=5.5.0 pandas==1.5.3 -xinference==0.2.1 +xinference==0.4.2 safetensors==0.3.2 \ No newline at end of file