feat: server xinference support (#927)
This commit is contained in: parent 8c991b5b26, commit da3f10a55e
@@ -57,6 +57,9 @@ class ModelProviderFactory:
         elif provider_name == 'huggingface_hub':
             from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
             return HuggingfaceHubProvider
+        elif provider_name == 'xinference':
+            from core.model_providers.providers.xinference_provider import XinferenceProvider
+            return XinferenceProvider
         else:
             raise NotImplementedError
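The factory resolves a provider name to its provider class through an if/elif chain; the hunk above adds the 'xinference' branch. A minimal, illustrative sketch of the same name-to-class mapping expressed as a lookup table (not the project's actual code):

# Illustrative sketch only: the elif chain above, rewritten as a dict lookup.
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from core.model_providers.providers.xinference_provider import XinferenceProvider

PROVIDER_CLASSES = {
    'huggingface_hub': HuggingfaceHubProvider,
    'xinference': XinferenceProvider,
}

def resolve_provider_class(provider_name: str):
    try:
        return PROVIDER_CLASSES[provider_name]
    except KeyError:
        # Mirrors the factory's behaviour for unknown providers.
        raise NotImplementedError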
api/core/model_providers/models/llm/xinference_model.py (Normal file, 69 lines)
@@ -0,0 +1,69 @@
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.llms import Xinference
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class XinferenceModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

        client = Xinference(
            **self.credentials,
        )

        client.callbacks = self.callbacks

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate(
            [prompts],
            stop,
            callbacks,
            generate_config={
                "stop": stop,
                "stream": self.streaming,
                **self.provider_model_kwargs,
            }
        )

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        pass

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Xinference: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        return True
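XinferenceModel is a thin wrapper over LangChain's Xinference LLM: the stored credentials (server_url, model_uid) are unpacked into the client, and stop words, streaming and provider kwargs are forwarded through generate_config. A minimal sketch of the underlying client used directly, assuming a Xinference server is reachable at the placeholder URL and model UID below:

from langchain.llms import Xinference

# Placeholder server URL and model UID; substitute values from a running
# Xinference deployment with a launched model.
llm = Xinference(
    server_url='http://127.0.0.1:9997',
    model_uid='your-model-uid',
)

# Same call shape the provider's credential check uses below:
# generate_config carries sampling options such as stop words and max_tokens.
answer = llm('Q: 1 + 1 = ?\nA:', generate_config={'max_tokens': 10})
print(answer)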
api/core/model_providers/providers/xinference_provider.py (Normal file, 141 lines)
@@ -0,0 +1,141 @@
import json
from typing import Type

from langchain.llms import Xinference

from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

from core.model_providers.models.base import BaseProviderModel
from models.provider import ProviderType


class XinferenceProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'xinference'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = XinferenceModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=0.7),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
            max_tokens=KwargRule[int](alias='max_token', min=10, max=4000, default=256),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        if 'server_url' not in credentials:
            raise CredentialsValidateFailedError('Xinference Server URL must be provided.')

        if 'model_uid' not in credentials:
            raise CredentialsValidateFailedError('Xinference Model UID must be provided.')

        try:
            credential_kwargs = {
                'server_url': credentials['server_url'],
                'model_uid': credentials['model_uid'],
            }

            llm = Xinference(
                **credential_kwargs
            )

            llm("ping", generate_config={'max_tokens': 10})
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'server_url': None,
                'model_uid': None,
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['server_url']:
            credentials['server_url'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['server_url']
            )

            if obfuscated:
                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}
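Note the asymmetry in the credential handling above: only server_url is encrypted on save, while model_uid is stored as-is, and get_model_credentials decrypts (and optionally obfuscates) the URL on read. A minimal sketch of that round trip, using hypothetical plain-text helpers in place of the tenant-scoped core.helper.encrypter functions, which are not part of this diff:

# Hedged sketch: fake_encrypt / fake_decrypt stand in for
# encrypter.encrypt_token / encrypter.decrypt_token (tenant-scoped in Dify).
def fake_encrypt(tenant_id: str, token: str) -> str:
    return f'encrypted_{token}'

def fake_decrypt(tenant_id: str, token: str) -> str:
    return token.replace('encrypted_', '')

credentials = {
    'server_url': 'http://127.0.0.1:9997/',
    'model_uid': 'fake-model-uid',
}

# On save (encrypt_model_credentials): only server_url is encrypted.
stored = dict(credentials, server_url=fake_encrypt('tenant_id', credentials['server_url']))

# On read (get_model_credentials): server_url is decrypted, model_uid untouched.
loaded = dict(stored, server_url=fake_decrypt('tenant_id', stored['server_url']))
assert loaded == credentials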
@@ -8,5 +8,6 @@
   "wenxin",
   "chatglm",
   "replicate",
-  "huggingface_hub"
+  "huggingface_hub",
+  "xinference"
 ]
api/core/model_providers/rules/xinference.json (Normal file, 7 lines)
@@ -0,0 +1,7 @@
{
  "support_provider_types": [
    "custom"
  ],
  "system_config": null,
  "model_flexibility": "configurable"
}
@@ -48,4 +48,5 @@ dashscope~=1.5.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
+xinference==0.2.0
@@ -32,4 +32,8 @@ WENXIN_API_KEY=
 WENXIN_SECRET_KEY=

 # ChatGLM Credentials
 CHATGLM_API_BASE=
+
+# Xinference Credentials
+XINFERENCE_SERVER_URL=
+XINFERENCE_MODEL_UID=
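These two variables are read straight from the environment by the new integration test further below; a minimal sketch of the expected setup (both values are placeholders):

import os

# Placeholder values; point these at a running Xinference server and model.
os.environ.setdefault('XINFERENCE_SERVER_URL', 'http://127.0.0.1:9997')
os.environ.setdefault('XINFERENCE_MODEL_UID', 'your-model-uid')

credentials = {
    'server_url': os.environ['XINFERENCE_SERVER_URL'],
    'model_uid': os.environ['XINFERENCE_MODEL_UID'],
}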
@@ -50,7 +50,9 @@ def test_get_num_tokens(mock_decrypt):


 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('claude-2')
     messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: ')]
     rst = model.run(
@@ -58,4 +60,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'
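The same three-part change repeats across the remaining provider tests below: the pytest-mock mocker fixture is added to the test signature, BaseModelProvider.update_last_used is patched to a no-op so running the model does not need a database session, and the brittle exact-output assertion is dropped. A small standalone sketch of the patching pattern (FakeProvider is a stand-in, not a Dify class):

# Hedged sketch of the pattern; requires pytest and pytest-mock (the `mocker`
# fixture wraps unittest.mock.patch with automatic teardown).
class FakeProvider:
    def update_last_used(self):
        raise RuntimeError('would touch the database')

def test_run_without_db(mocker):
    # Patch the hook to a no-op, exactly like the hunks below patch
    # BaseModelProvider.update_last_used.
    mocker.patch.object(FakeProvider, 'update_last_used', return_value=None)
    assert FakeProvider().update_last_used() is None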
@@ -76,6 +76,8 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):

 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
     messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
     rst = openai_model.run(
@@ -83,4 +85,3 @@ def test_run(mock_decrypt, mocker):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'

@@ -95,6 +95,8 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocker):

 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_hosted_inference_api_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model(
         'google/flan-t5-base',
         'hosted_inference_api',
@@ -111,6 +113,8 @@ def test_hosted_inference_api_run(mock_decrypt, mocker):

 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_inference_endpoints_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model(
         '',
         'inference_endpoints',
@@ -121,4 +125,3 @@ def test_inference_endpoints_run(mock_decrypt, mocker):
         [PromptMessage(content='Answer the following yes/no question. Can you write a whole Haiku in a single tweet?')],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'no'

@@ -54,11 +54,12 @@ def test_get_num_tokens(mock_decrypt):


 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('abab5.5-chat')
     rst = model.run(
         [PromptMessage(content='Human: Are you a real Human? you MUST only answer `y` or `n`? \nAssistant: ')],
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'

@@ -58,7 +58,9 @@ def test_chat_get_num_tokens(mock_decrypt):


 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_openai_model('text-davinci-003')
     rst = openai_model.run(
         [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
@@ -69,7 +71,9 @@ def test_run(mock_decrypt):


 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_chat_run(mock_decrypt):
+def test_chat_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_openai_model('gpt-3.5-turbo')
     messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
     rst = openai_model.run(
@@ -77,4 +81,3 @@ def test_chat_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'

@@ -65,6 +65,8 @@ def test_get_num_tokens(mock_decrypt, mocker):

 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
     messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
     rst = model.run(

@@ -58,7 +58,9 @@ def test_get_num_tokens(mock_decrypt):


 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('spark')
     messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
     rst = model.run(
@@ -66,4 +68,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'

@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):


 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('qwen-v1')
     rst = model.run(
         [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],

@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):


 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('ernie-bot')
     messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
     rst = model.run(
@@ -60,4 +62,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'
@@ -0,0 +1,74 @@
import json
import os
from unittest.mock import patch, MagicMock

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.xinference_provider import XinferenceProvider
from models.provider import Provider, ProviderType, ProviderModel


def get_mock_provider():
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='xinference',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config='',
        is_valid=True,
    )


def get_mock_model(model_name, mocker):
    model_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0.01
    )
    server_url = os.environ['XINFERENCE_SERVER_URL']
    model_uid = os.environ['XINFERENCE_MODEL_UID']
    model_provider = XinferenceProvider(provider=get_mock_provider())

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        provider_name='xinference',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=json.dumps({
            'server_url': server_url,
            'model_uid': model_uid
        }),
        is_valid=True,
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    return XinferenceModel(
        model_provider=model_provider,
        name=model_name,
        model_kwargs=model_kwargs
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    model = get_mock_model('llama-2-chat', mocker)
    rst = model.get_num_tokens([
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ])
    assert rst == 5


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('llama-2-chat', mocker)
    messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
    rst = model.run(
        messages
    )
    assert len(rst.content) > 0
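Both tests above talk to a real Xinference endpoint (taken from XINFERENCE_SERVER_URL and XINFERENCE_MODEL_UID) while mocking the database query, so they run without the full application stack. A hedged sketch of invoking them, assuming the file lives alongside the other LLM integration tests (the path is an assumption, not shown in this diff):

import os
import pytest

# Assumed location, mirroring the other LLM integration tests.
TEST_FILE = 'api/tests/integration_tests/models/llm/test_xinference_model.py'

# Placeholder credentials for a running Xinference server.
os.environ.setdefault('XINFERENCE_SERVER_URL', 'http://127.0.0.1:9997')
os.environ.setdefault('XINFERENCE_MODEL_UID', 'your-model-uid')

raise SystemExit(pytest.main(['-x', TEST_FILE]))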
api/tests/unit_tests/model_providers/test_xinference_provider.py (Normal file, 124 lines)
@@ -0,0 +1,124 @@
import pytest
from unittest.mock import patch, MagicMock
import json

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.replicate_provider import ReplicateProvider
from core.model_providers.providers.xinference_provider import XinferenceProvider
from models.provider import ProviderType, Provider, ProviderModel

PROVIDER_NAME = 'xinference'
MODEL_PROVIDER_CLASS = XinferenceProvider
VALIDATE_CREDENTIAL = {
    'model_uid': 'fake-model-uid',
    'server_url': 'http://127.0.0.1:9997/'
}


def encrypt_side_effect(tenant_id, encrypt_key):
    return f'encrypted_{encrypt_key}'


def decrypt_side_effect(tenant_id, encrypted_key):
    return encrypted_key.replace('encrypted_', '')


def test_is_credentials_valid_or_raise_valid(mocker):
    mocker.patch('langchain.llms.xinference.Xinference._call',
                 return_value="abc")

    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
        model_name='username/test_model_name',
        model_type=ModelType.TEXT_GENERATION,
        credentials=VALIDATE_CREDENTIAL.copy()
    )


def test_is_credentials_valid_or_raise_invalid():
    # raise CredentialsValidateFailedError if server_url or model_uid is missing from credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
            model_name='test_model_name',
            model_type=ModelType.TEXT_GENERATION,
            credentials={}
        )

    # raise CredentialsValidateFailedError if server_url is invalid
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
            model_name='test_model_name',
            model_type=ModelType.TEXT_GENERATION,
            credentials={'server_url': 'invalid'})


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
    api_key = 'http://127.0.0.1:9997/'
    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
        tenant_id='tenant_id',
        model_name='test_model_name',
        model_type=ModelType.TEXT_GENERATION,
        credentials=VALIDATE_CREDENTIAL.copy()
    )
    mock_encrypt.assert_called_with('tenant_id', api_key)
    assert result['server_url'] == f'encrypted_{api_key}'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(encrypted_credential)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.TEXT_GENERATION
    )
    assert result['server_url'] == 'http://127.0.0.1:9997/'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(encrypted_credential)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.TEXT_GENERATION,
        obfuscated=True
    )
    middle_token = result['server_url'][6:-2]
    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
    assert all(char == '*' for char in middle_token)