mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-05 16:05:12 +08:00
fix: inference embedding validate (#1187)
This commit is contained in:
parent
ec5f585df4
commit
c8bd76cd66
@@ -2,6 +2,7 @@ import json
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from langchain.embeddings import XinferenceEmbeddings
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
|
||||
'model_uid': credentials['model_uid'],
|
||||
}
|
||||
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
llm = XinferenceLLM(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
embedding = XinferenceEmbeddings(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
embedding.embed_query("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@@ -117,6 +125,7 @@ class XinferenceProvider(BaseModelProvider):
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
extra_credentials = cls._get_extra_credentials(credentials)
|
||||
credentials.update(extra_credentials)
|
||||
|
||||
|
@@ -19,7 +19,7 @@ pytest~=7.3.1
|
||||
pytest-mock~=3.11.1
|
||||
tiktoken==0.3.3
|
||||
Authlib==1.2.0
|
||||
boto3~=1.26.123
|
||||
boto3==1.28.17
|
||||
tenacity==8.2.2
|
||||
cachetools~=5.3.0
|
||||
weaviate-client~=3.21.0
|
||||
@@ -49,5 +49,5 @@ huggingface_hub~=0.16.4
|
||||
transformers~=4.31.0
|
||||
stripe~=5.5.0
|
||||
pandas==1.5.3
|
||||
xinference==0.2.1
|
||||
xinference==0.4.2
|
||||
safetensors==0.3.2
|
Loading…
x
Reference in New Issue
Block a user