From da3f10a55e121f0249fbeab6788c9d975f72ae3e Mon Sep 17 00:00:00 2001
From: takatost
Date: Sun, 20 Aug 2023 17:46:41 +0800
Subject: [PATCH] feat: server xinference support (#927)

---
 .../model_providers/model_provider_factory.py |   3 +
 .../models/llm/xinference_model.py            |  69 +++++++++
 .../providers/xinference_provider.py          | 141 ++++++++++++++++++
 .../model_providers/rules/_providers.json     |   3 +-
 .../model_providers/rules/xinference.json     |   7 +
 api/requirements.txt                          |   3 +-
 api/tests/integration_tests/.env.example      |   6 +-
 .../models/llm/test_anthropic_model.py        |   5 +-
 .../models/llm/test_azure_openai_model.py     |   3 +-
 .../models/llm/test_huggingface_hub_model.py  |   5 +-
 .../models/llm/test_minimax_model.py          |   5 +-
 .../models/llm/test_openai_model.py           |   9 +-
 .../models/llm/test_replicate_model.py        |   2 +
 .../models/llm/test_spark_model.py            |   5 +-
 .../models/llm/test_tongyi_model.py           |   4 +-
 .../models/llm/test_wenxin_model.py           |   5 +-
 .../models/llm/test_xinference_model.py       |  74 +++++++++
 .../test_xinference_provider.py               | 124 +++++++++++++++
 18 files changed, 456 insertions(+), 17 deletions(-)
 create mode 100644 api/core/model_providers/models/llm/xinference_model.py
 create mode 100644 api/core/model_providers/providers/xinference_provider.py
 create mode 100644 api/core/model_providers/rules/xinference.json
 create mode 100644 api/tests/integration_tests/models/llm/test_xinference_model.py
 create mode 100644 api/tests/unit_tests/model_providers/test_xinference_provider.py

diff --git a/api/core/model_providers/model_provider_factory.py b/api/core/model_providers/model_provider_factory.py
index 6cb2f8fa46..c04f07f4d9 100644
--- a/api/core/model_providers/model_provider_factory.py
+++ b/api/core/model_providers/model_provider_factory.py
@@ -57,6 +57,9 @@ class ModelProviderFactory:
         elif provider_name == 'huggingface_hub':
             from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
             return HuggingfaceHubProvider
+        elif provider_name == 'xinference':
+            from core.model_providers.providers.xinference_provider import XinferenceProvider
+            return XinferenceProvider
         else:
             raise NotImplementedError
 
diff --git a/api/core/model_providers/models/llm/xinference_model.py b/api/core/model_providers/models/llm/xinference_model.py
new file mode 100644
index 0000000000..ef3a83c352
--- /dev/null
+++ b/api/core/model_providers/models/llm/xinference_model.py
@@ -0,0 +1,69 @@
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.llms import Xinference
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+
+
+class XinferenceModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.COMPLETION
+
+    def _init_client(self) -> Any:
+        self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+
+        client = Xinference(
+            **self.credentials,
+        )
+
+        client.callbacks = self.callbacks
+
+        return client
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate(
+            [prompts],
+            stop,
+            callbacks,
+            generate_config={
+                "stop": stop,
+                "stream": self.streaming,
+                **self.provider_model_kwargs,
+            }
+        )
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return max(self._client.get_num_tokens(prompts), 0)
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        pass
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        return LLMBadRequestError(f"Xinference: {str(ex)}")
+
+    @classmethod
+    def support_streaming(cls):
+        return True
diff --git a/api/core/model_providers/providers/xinference_provider.py b/api/core/model_providers/providers/xinference_provider.py
new file mode 100644
index 0000000000..4589b3f853
--- /dev/null
+++ b/api/core/model_providers/providers/xinference_provider.py
@@ -0,0 +1,141 @@
+import json
+from typing import Type
+
+from langchain.llms import Xinference
+
+from core.helper import encrypter
+from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
+from core.model_providers.models.llm.xinference_model import XinferenceModel
+from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+
+from core.model_providers.models.base import BaseProviderModel
+from models.provider import ProviderType
+
+
+class XinferenceProvider(BaseModelProvider):
+    @property
+    def provider_name(self):
+        """
+        Returns the name of a provider.
+        """
+        return 'xinference'
+
+    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
+        return []
+
+    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
+        """
+        Returns the model class.
+
+        :param model_type:
+        :return:
+        """
+        if model_type == ModelType.TEXT_GENERATION:
+            model_class = XinferenceModel
+        else:
+            raise NotImplementedError
+
+        return model_class
+
+    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
+        """
+        get model parameter rules.
+
+        :param model_name:
+        :param model_type:
+        :return:
+        """
+        return ModelKwargsRules(
+            temperature=KwargRule[float](min=0, max=2, default=1),
+            top_p=KwargRule[float](min=0, max=1, default=0.7),
+            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+            max_tokens=KwargRule[int](alias='max_token', min=10, max=4000, default=256),
+        )
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        check model credentials valid.
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('Xinference Server URL must be provided.')
+
+        if 'model_uid' not in credentials:
+            raise CredentialsValidateFailedError('Xinference Model UID must be provided.')
+
+        try:
+            credential_kwargs = {
+                'server_url': credentials['server_url'],
+                'model_uid': credentials['model_uid'],
+            }
+
+            llm = Xinference(
+                **credential_kwargs
+            )
+
+            llm("ping", generate_config={'max_tokens': 10})
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
+                                  credentials: dict) -> dict:
+        """
+        encrypt model credentials for save.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
+        return credentials
+
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
+
+        :param model_name:
+        :param model_type:
+        :param obfuscated:
+        :return:
+        """
+        if self.provider.provider_type != ProviderType.CUSTOM.value:
+            raise NotImplementedError
+
+        provider_model = self._get_provider_model(model_name, model_type)
+
+        if not provider_model.encrypted_config:
+            return {
+                'server_url': None,
+                'model_uid': None,
+            }
+
+        credentials = json.loads(provider_model.encrypted_config)
+        if credentials['server_url']:
+            credentials['server_url'] = encrypter.decrypt_token(
+                self.provider.tenant_id,
+                credentials['server_url']
+            )
+
+            if obfuscated:
+                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
+
+        return credentials
+
+    @classmethod
+    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
+        return
+
+    @classmethod
+    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
+        return {}
+
+    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
+        return {}
diff --git a/api/core/model_providers/rules/_providers.json b/api/core/model_providers/rules/_providers.json
index ad53f425ce..7fb4cce70c 100644
--- a/api/core/model_providers/rules/_providers.json
+++ b/api/core/model_providers/rules/_providers.json
@@ -8,5 +8,6 @@
   "wenxin",
   "chatglm",
   "replicate",
-  "huggingface_hub"
+  "huggingface_hub",
+  "xinference"
 ]
\ No newline at end of file
diff --git a/api/core/model_providers/rules/xinference.json b/api/core/model_providers/rules/xinference.json
new file mode 100644
index 0000000000..5badb07178
--- /dev/null
+++ b/api/core/model_providers/rules/xinference.json
@@ -0,0 +1,7 @@
+{
+  "support_provider_types": [
+    "custom"
+  ],
+  "system_config": null,
+  "model_flexibility": "configurable"
+}
\ No newline at end of file
diff --git a/api/requirements.txt b/api/requirements.txt
index d7d546c856..43028d9225 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -48,4 +48,5 @@ dashscope~=1.5.0
 huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
-pandas==1.5.3
\ No newline at end of file
+pandas==1.5.3
+xinference==0.2.0
\ No newline at end of file
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index f1ee239415..c24ec68c30 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -32,4 +32,8 @@ WENXIN_API_KEY=
 WENXIN_SECRET_KEY=
 
 # ChatGLM Credentials
-CHATGLM_API_BASE=
\ No newline at end of file
+CHATGLM_API_BASE=
+
+# Xinference Credentials
+XINFERENCE_SERVER_URL=
+XINFERENCE_MODEL_UID=
\ No newline at end of file
diff --git a/api/tests/integration_tests/models/llm/test_anthropic_model.py b/api/tests/integration_tests/models/llm/test_anthropic_model.py
index 86cfe9922d..32013b27aa 100644
--- a/api/tests/integration_tests/models/llm/test_anthropic_model.py
+++ b/api/tests/integration_tests/models/llm/test_anthropic_model.py
@@ -50,7 +50,9 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('claude-2')
     messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: ')]
     rst = model.run(
@@ -58,4 +60,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'
diff --git a/api/tests/integration_tests/models/llm/test_azure_openai_model.py b/api/tests/integration_tests/models/llm/test_azure_openai_model.py
index 112dcd6841..1df272d1cc 100644
--- a/api/tests/integration_tests/models/llm/test_azure_openai_model.py
+++ b/api/tests/integration_tests/models/llm/test_azure_openai_model.py
@@ -76,6 +76,8 @@ def test_chat_get_num_tokens(mock_decrypt, mocker):
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
     messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
     rst = openai_model.run(
@@ -83,4 +85,3 @@ def test_run(mock_decrypt, mocker):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'
diff --git a/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py b/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py
index d55d6e93fe..eda95102c9 100644
--- a/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py
+++ b/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py
@@ -95,6 +95,8 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_hosted_inference_api_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model(
         'google/flan-t5-base',
         'hosted_inference_api',
@@ -111,6 +113,8 @@ def test_hosted_inference_api_run(mock_decrypt, mocker):
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_inference_endpoints_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model(
         '',
         'inference_endpoints',
@@ -121,4 +125,3 @@ def test_inference_endpoints_run(mock_decrypt, mocker):
         [PromptMessage(content='Answer the following yes/no question. Can you write a whole Haiku in a single tweet?')],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'no'
diff --git a/api/tests/integration_tests/models/llm/test_minimax_model.py b/api/tests/integration_tests/models/llm/test_minimax_model.py
index 79a05bc279..d93f8ad735 100644
--- a/api/tests/integration_tests/models/llm/test_minimax_model.py
+++ b/api/tests/integration_tests/models/llm/test_minimax_model.py
@@ -54,11 +54,12 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('abab5.5-chat')
     rst = model.run(
         [PromptMessage(content='Human: Are you a real Human? you MUST only answer `y` or `n`? \nAssistant: ')],
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'
diff --git a/api/tests/integration_tests/models/llm/test_openai_model.py b/api/tests/integration_tests/models/llm/test_openai_model.py
index ebc40fd529..8c3dd70dd8 100644
--- a/api/tests/integration_tests/models/llm/test_openai_model.py
+++ b/api/tests/integration_tests/models/llm/test_openai_model.py
@@ -58,7 +58,9 @@ def test_chat_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_openai_model('text-davinci-003')
     rst = openai_model.run(
         [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
@@ -69,7 +71,9 @@ def test_run(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_chat_run(mock_decrypt):
+def test_chat_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     openai_model = get_mock_openai_model('gpt-3.5-turbo')
     messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
     rst = openai_model.run(
@@ -77,4 +81,3 @@ def test_chat_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == 'n'
diff --git a/api/tests/integration_tests/models/llm/test_replicate_model.py b/api/tests/integration_tests/models/llm/test_replicate_model.py
index 7689a3c0fc..13efc19881 100644
--- a/api/tests/integration_tests/models/llm/test_replicate_model.py
+++ b/api/tests/integration_tests/models/llm/test_replicate_model.py
@@ -65,6 +65,8 @@ def test_get_num_tokens(mock_decrypt, mocker):
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker)
     messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
     rst = model.run(
diff --git a/api/tests/integration_tests/models/llm/test_spark_model.py b/api/tests/integration_tests/models/llm/test_spark_model.py
index 4e62aeb2cd..d07bfb279a 100644
--- a/api/tests/integration_tests/models/llm/test_spark_model.py
+++ b/api/tests/integration_tests/models/llm/test_spark_model.py
@@ -58,7 +58,9 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('spark')
     messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
     rst = model.run(
@@ -66,4 +68,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'
diff --git a/api/tests/integration_tests/models/llm/test_tongyi_model.py b/api/tests/integration_tests/models/llm/test_tongyi_model.py
index 2f9e33992f..c2254dec0c 100644
--- a/api/tests/integration_tests/models/llm/test_tongyi_model.py
+++ b/api/tests/integration_tests/models/llm/test_tongyi_model.py
@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('qwen-v1')
     rst = model.run(
         [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
diff --git a/api/tests/integration_tests/models/llm/test_wenxin_model.py b/api/tests/integration_tests/models/llm/test_wenxin_model.py
index f517d05c25..29a0de3262 100644
--- a/api/tests/integration_tests/models/llm/test_wenxin_model.py
+++ b/api/tests/integration_tests/models/llm/test_wenxin_model.py
@@ -52,7 +52,9 @@ def test_get_num_tokens(mock_decrypt):
 
 
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_run(mock_decrypt):
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
     model = get_mock_model('ernie-bot')
     messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
     rst = model.run(
@@ -60,4 +62,3 @@ def test_run(mock_decrypt):
         stop=['\nHuman:'],
     )
     assert len(rst.content) > 0
-    assert rst.content.strip() == '2'
diff --git a/api/tests/integration_tests/models/llm/test_xinference_model.py b/api/tests/integration_tests/models/llm/test_xinference_model.py
new file mode 100644
index 0000000000..aab075fae2
--- /dev/null
+++ b/api/tests/integration_tests/models/llm/test_xinference_model.py
@@ -0,0 +1,74 @@
+import json
+import os
+from unittest.mock import patch, MagicMock
+
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
+from core.model_providers.models.llm.xinference_model import XinferenceModel
+from core.model_providers.providers.xinference_provider import XinferenceProvider
+from models.provider import Provider, ProviderType, ProviderModel
+
+
+def get_mock_provider():
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='xinference',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config='',
+        is_valid=True,
+    )
+
+
+def get_mock_model(model_name, mocker):
+    model_kwargs = ModelKwargs(
+        max_tokens=10,
+        temperature=0.01
+    )
+    server_url = os.environ['XINFERENCE_SERVER_URL']
+    model_uid = os.environ['XINFERENCE_MODEL_UID']
+    model_provider = XinferenceProvider(provider=get_mock_provider())
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        provider_name='xinference',
+        model_name=model_name,
+        model_type=ModelType.TEXT_GENERATION.value,
+        encrypted_config=json.dumps({
+            'server_url': server_url,
+            'model_uid': model_uid
+        }),
+        is_valid=True,
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    return XinferenceModel(
+        model_provider=model_provider,
+        name=model_name,
+        model_kwargs=model_kwargs
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_num_tokens(mock_decrypt, mocker):
+    model = get_mock_model('llama-2-chat', mocker)
+    rst = model.get_num_tokens([
+        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+    ])
+    assert rst == 5
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    model = get_mock_model('llama-2-chat', mocker)
+    messages = [PromptMessage(content='Human: 1+1=? \nAnswer: ')]
+    rst = model.run(
+        messages
+    )
+    assert len(rst.content) > 0
diff --git a/api/tests/unit_tests/model_providers/test_xinference_provider.py b/api/tests/unit_tests/model_providers/test_xinference_provider.py
new file mode 100644
index 0000000000..4cf85dcce3
--- /dev/null
+++ b/api/tests/unit_tests/model_providers/test_xinference_provider.py
@@ -0,0 +1,124 @@
+import pytest
+from unittest.mock import patch, MagicMock
+import json
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.replicate_provider import ReplicateProvider
+from core.model_providers.providers.xinference_provider import XinferenceProvider
+from models.provider import ProviderType, Provider, ProviderModel
+
+PROVIDER_NAME = 'xinference'
+MODEL_PROVIDER_CLASS = XinferenceProvider
+VALIDATE_CREDENTIAL = {
+    'model_uid': 'fake-model-uid',
+    'server_url': 'http://127.0.0.1:9997/'
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('langchain.llms.xinference.Xinference._call',
+                 return_value="abc")
+
+    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+        model_name='username/test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+
+
+def test_is_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if server_url is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+            model_name='test_model_name',
+            model_type=ModelType.TEXT_GENERATION,
+            credentials={}
+        )
+
+    # raise CredentialsValidateFailedError if model_uid is missing or server_url is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+            model_name='test_model_name',
+            model_type=ModelType.TEXT_GENERATION,
+            credentials={'server_url': 'invalid'})
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_model_credentials(mock_encrypt):
+    api_key = 'http://127.0.0.1:9997/'
+    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
+        tenant_id='tenant_id',
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+    mock_encrypt.assert_called_with('tenant_id', api_key)
+    assert result['server_url'] == f'encrypted_{api_key}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_custom(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION
+    )
+    assert result['server_url'] == 'http://127.0.0.1:9997/'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        obfuscated=True
+    )
+    middle_token = result['server_url'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
+    assert all(char == '*' for char in middle_token)