fix: inference embedding validate (#1187)

This commit is contained in:
takatost 2023-09-16 03:09:36 +08:00 committed by GitHub
parent ec5f585df4
commit c8bd76cd66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 8 deletions

View File

@ -2,6 +2,7 @@ import json
from typing import Type
import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
'model_uid': credentials['model_uid'],
}
if model_type == ModelType.TEXT_GENERATION:
llm = XinferenceLLM(
**credential_kwargs
)
llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = XinferenceEmbeddings(
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@ -117,6 +125,7 @@ class XinferenceProvider(BaseModelProvider):
:param credentials:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)

View File

@ -19,7 +19,7 @@ pytest~=7.3.1
pytest-mock~=3.11.1
tiktoken==0.3.3
Authlib==1.2.0
boto3~=1.26.123
boto3==1.28.17
tenacity==8.2.2
cachetools~=5.3.0
weaviate-client~=3.21.0
@ -49,5 +49,5 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.2.1
xinference==0.4.2
safetensors==0.3.2