diff --git a/api/core/model_providers/model_provider_factory.py b/api/core/model_providers/model_provider_factory.py index c04f07f4d9..3a7e9b5c97 100644 --- a/api/core/model_providers/model_provider_factory.py +++ b/api/core/model_providers/model_provider_factory.py @@ -60,6 +60,9 @@ class ModelProviderFactory: elif provider_name == 'xinference': from core.model_providers.providers.xinference_provider import XinferenceProvider return XinferenceProvider + elif provider_name == 'openllm': + from core.model_providers.providers.openllm_provider import OpenLLMProvider + return OpenLLMProvider else: raise NotImplementedError diff --git a/api/core/model_providers/models/llm/openllm_model.py b/api/core/model_providers/models/llm/openllm_model.py new file mode 100644 index 0000000000..eba0d44eb5 --- /dev/null +++ b/api/core/model_providers/models/llm/openllm_model.py @@ -0,0 +1,60 @@ +from typing import List, Optional, Any + +from langchain.callbacks.manager import Callbacks +from langchain.llms import OpenLLM +from langchain.schema import LLMResult + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs + + +class OpenLLMModel(BaseLLM): + model_mode: ModelMode = ModelMode.COMPLETION + + def _init_client(self) -> Any: + self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + + client = OpenLLM( + server_url=self.credentials.get('server_url'), + callbacks=self.callbacks, + **self.provider_model_kwargs + ) + + return client + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. + + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return max(self._client.get_num_tokens(prompts), 0) + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + pass + + def handle_exceptions(self, ex: Exception) -> Exception: + return LLMBadRequestError(f"OpenLLM: {str(ex)}") + + @classmethod + def support_streaming(cls): + return False diff --git a/api/core/model_providers/providers/openllm_provider.py b/api/core/model_providers/providers/openllm_provider.py new file mode 100644 index 0000000000..efcc62c52e --- /dev/null +++ b/api/core/model_providers/providers/openllm_provider.py @@ -0,0 +1,137 @@ +import json +from typing import Type + +from langchain.llms import OpenLLM + +from core.helper import encrypter +from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType +from core.model_providers.models.llm.openllm_model import OpenLLMModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError + +from core.model_providers.models.base import BaseProviderModel +from models.provider import ProviderType + + +class OpenLLMProvider(BaseModelProvider): + @property + def provider_name(self): + """ + Returns the name of a provider. 
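+        :return: provider name key, 'openllm', matching the factory branch and _providers.json entry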
+ """ + return 'openllm' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = OpenLLMModel + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + return ModelKwargsRules( + temperature=KwargRule[float](min=0, max=2, default=1), + top_p=KwargRule[float](min=0, max=1, default=0.7), + presence_penalty=KwargRule[float](min=-2, max=2, default=0), + frequency_penalty=KwargRule[float](min=-2, max=2, default=0), + max_tokens=KwargRule[int](min=10, max=4000, default=128), + ) + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + if 'server_url' not in credentials: + raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.') + + try: + credential_kwargs = { + 'server_url': credentials['server_url'] + } + + llm = OpenLLM( + max_tokens=10, + **credential_kwargs + ) + + llm("ping") + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. + + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url']) + return credentials + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. 
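+        The stored server_url is decrypted before return; when obfuscated is True it is masked for display.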
+ + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + if self.provider.provider_type != ProviderType.CUSTOM.value: + raise NotImplementedError + + provider_model = self._get_provider_model(model_name, model_type) + + if not provider_model.encrypted_config: + return { + 'server_url': None + } + + credentials = json.loads(provider_model.encrypted_config) + if credentials['server_url']: + credentials['server_url'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['server_url'] + ) + + if obfuscated: + credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url']) + + return credentials + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + return + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + return {} + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + return {} diff --git a/api/core/model_providers/rules/_providers.json b/api/core/model_providers/rules/_providers.json index 7fb4cce70c..0c1463663a 100644 --- a/api/core/model_providers/rules/_providers.json +++ b/api/core/model_providers/rules/_providers.json @@ -9,5 +9,6 @@ "chatglm", "replicate", "huggingface_hub", - "xinference" + "xinference", + "openllm" ] \ No newline at end of file diff --git a/api/core/model_providers/rules/openllm.json b/api/core/model_providers/rules/openllm.json new file mode 100644 index 0000000000..5badb07178 --- /dev/null +++ b/api/core/model_providers/rules/openllm.json @@ -0,0 +1,7 @@ +{ + "support_provider_types": [ + "custom" + ], + "system_config": null, + "model_flexibility": "configurable" +} \ No newline at end of file diff --git a/api/requirements.txt b/api/requirements.txt index 43028d9225..3ae232cd6f 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -49,4 +49,5 @@ huggingface_hub~=0.16.4 transformers~=4.31.0 stripe~=5.5.0 pandas==1.5.3 -xinference==0.2.0 \ No newline at end of file +xinference==0.2.0 +openllm~=0.2.26 \ No newline at end of file diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index c24ec68c30..7c39d308bf 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -36,4 +36,7 @@ CHATGLM_API_BASE= # Xinference Credentials XINFERENCE_SERVER_URL= -XINFERENCE_MODEL_UID= \ No newline at end of file +XINFERENCE_MODEL_UID= + +# OpenLLM Credentials +OPENLLM_SERVER_URL= \ No newline at end of file diff --git a/api/tests/integration_tests/models/llm/test_openllm_model.py b/api/tests/integration_tests/models/llm/test_openllm_model.py new file mode 100644 index 0000000000..d515f35048 --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_openllm_model.py @@ -0,0 +1,72 @@ +import json +import os +from unittest.mock import patch, MagicMock + +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs, ModelType +from core.model_providers.models.llm.openllm_model import OpenLLMModel +from core.model_providers.providers.openllm_provider import OpenLLMProvider +from models.provider import Provider, ProviderType, ProviderModel + + +def get_mock_provider(): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='openllm', + provider_type=ProviderType.CUSTOM.value, + encrypted_config='', + is_valid=True, + ) + + +def get_mock_model(model_name, mocker): + model_kwargs = ModelKwargs( + 
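+        # short, near-deterministic output keeps the integration assertions stable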
max_tokens=10, + temperature=0.01 + ) + server_url = os.environ['OPENLLM_SERVER_URL'] + model_provider = OpenLLMProvider(provider=get_mock_provider()) + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + provider_name='openllm', + model_name=model_name, + model_type=ModelType.TEXT_GENERATION.value, + encrypted_config=json.dumps({ + 'server_url': server_url + }), + is_valid=True, + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + return OpenLLMModel( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_num_tokens(mock_decrypt, mocker): + model = get_mock_model('facebook/opt-125m', mocker) + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst == 5 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt, mocker): + mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) + + model = get_mock_model('facebook/opt-125m', mocker) + messages = [PromptMessage(content='Human: who are you? \nAnswer: ')] + rst = model.run( + messages + ) + assert len(rst.content) > 0 diff --git a/api/tests/unit_tests/model_providers/test_openllm_provider.py b/api/tests/unit_tests/model_providers/test_openllm_provider.py new file mode 100644 index 0000000000..42609ed360 --- /dev/null +++ b/api/tests/unit_tests/model_providers/test_openllm_provider.py @@ -0,0 +1,125 @@ +import pytest +from unittest.mock import patch, MagicMock +import json + +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.base import CredentialsValidateFailedError +from core.model_providers.providers.openllm_provider import OpenLLMProvider +from models.provider import ProviderType, Provider, ProviderModel + +PROVIDER_NAME = 'openllm' +MODEL_PROVIDER_CLASS = OpenLLMProvider +VALIDATE_CREDENTIAL = { + 'server_url': 'http://127.0.0.1:3333/' +} + + +def encrypt_side_effect(tenant_id, encrypt_key): + return f'encrypted_{encrypt_key}' + + +def decrypt_side_effect(tenant_id, encrypted_key): + return encrypted_key.replace('encrypted_', '') + + +def test_is_credentials_valid_or_raise_valid(mocker): + mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None) + mocker.patch('langchain.llms.openllm.OpenLLM._call', + return_value="abc") + + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='username/test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials=VALIDATE_CREDENTIAL.copy() + ) + + +def test_is_credentials_valid_or_raise_invalid(mocker): + mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None) + + # raise CredentialsValidateFailedError if credential is not in credentials + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials={} + ) + + # raise CredentialsValidateFailedError if credential is invalid + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + 
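+            # a syntactically present but unreachable server_url: the ping during validation should raise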
credentials={'server_url': 'invalid'}) + + +@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) +def test_encrypt_model_credentials(mock_encrypt): + api_key = 'http://127.0.0.1:3333/' + result = MODEL_PROVIDER_CLASS.encrypt_model_credentials( + tenant_id='tenant_id', + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials=VALIDATE_CREDENTIAL.copy() + ) + mock_encrypt.assert_called_with('tenant_id', api_key) + assert result['server_url'] == f'encrypted_{api_key}' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_model_credentials_custom(mock_decrypt, mocker): + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=None, + is_valid=True, + ) + + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url'] + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + encrypted_config=json.dumps(encrypted_credential) + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_model_credentials( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION + ) + assert result['server_url'] == 'http://127.0.0.1:3333/' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_model_credentials_obfuscated(mock_decrypt, mocker): + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=None, + is_valid=True, + ) + + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url'] + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + encrypted_config=json.dumps(encrypted_credential) + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_model_credentials( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + obfuscated=True + ) + middle_token = result['server_url'][6:-2] + assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0) + assert all(char == '*' for char in middle_token)
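
---

Usage sketch (not part of the diff): a minimal example of exercising the new provider end to end, assuming an OpenLLM server is already running (with the 0.2.x CLI, something like `openllm start opt --model-id facebook/opt-125m`, which serves on port 3000 by default) and that a `Provider` row is available as in the integration test above. `my_provider_row`, the model name, and the URL are illustrative placeholders, not values introduced by this change:

```python
# Minimal sketch: validate credentials, then run one completion through the
# classes added in this diff. Assumes a running OpenLLM server; the URL,
# model name, and my_provider_row are placeholders.
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.providers.openllm_provider import OpenLLMProvider

# Fail fast if the server URL is missing or unreachable.
OpenLLMProvider.is_model_credentials_valid_or_raise(
    model_name='facebook/opt-125m',
    model_type=ModelType.TEXT_GENERATION,
    credentials={'server_url': 'http://127.0.0.1:3000'},
)

# my_provider_row: a models.provider.Provider with provider_name='openllm',
# loaded from the database as in the tests above.
provider = OpenLLMProvider(provider=my_provider_row)
model_class = provider.get_model_class(ModelType.TEXT_GENERATION)  # -> OpenLLMModel
model = model_class(
    model_provider=provider,
    name='facebook/opt-125m',
    model_kwargs=ModelKwargs(max_tokens=64, temperature=0.7),
)

# run() wraps _run(); the result exposes .content, as asserted in test_run above.
result = model.run([PromptMessage(content='Who are you?')])
print(result.content)
```

Note that `support_streaming()` returns False here, so callers should expect a single blocking completion rather than token-by-token output.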