diff --git a/.github/workflows/api-unit-tests.yml b/.github/workflows/api-unit-tests.yml new file mode 100644 index 0000000000..6e795c953f --- /dev/null +++ b/.github/workflows/api-unit-tests.yml @@ -0,0 +1,38 @@ +name: Run Pytest + +on: + pull_request: + branches: + - main + push: + branches: + - deploy/dev + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.10' + + - name: Cache pip dependencies + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }} + restore-keys: ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + pip install -r api/requirements.txt + + - name: Run pytest + run: pytest api/tests/unit_tests diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index c44f4edc62..9e949865f2 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,5 +1,6 @@ # -*- coding:utf-8 -*- import json +import logging from datetime import datetime from flask_login import login_required, current_user @@ -11,7 +12,9 @@ from controllers.console import api from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError from core.model_providers.model_factory import ModelFactory +from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_providers.models.entity.model_params import ModelType from events.app_event import app_was_created, app_was_deleted from libs.helper import TimestampField @@ -124,24 +127,34 @@ class AppListApi(Resource): if current_user.current_tenant.current_role not in ['admin', 'owner']: raise Forbidden() - default_model = ModelFactory.get_default_model( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_GENERATION - ) - - if default_model: - default_model_provider = default_model.provider_name - default_model_name = default_model.model_name - else: - raise ProviderNotInitializeError( - f"No Text Generation Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + try: + default_model = ModelFactory.get_text_generation_model( + tenant_id=current_user.current_tenant_id + ) + except (ProviderTokenNotInitError, LLMBadRequestError): + default_model = None + except Exception as e: + logging.exception(e) + default_model = None if args['model_config'] is not None: # validate config model_config_dict = args['model_config'] - model_config_dict["model"]["provider"] = default_model_provider - model_config_dict["model"]["name"] = default_model_name + + # get model provider + model_provider = ModelProviderFactory.get_preferred_model_provider( + current_user.current_tenant_id, + model_config_dict["model"]["provider"] + ) + + if not model_provider: + if not default_model: + raise ProviderNotInitializeError( + f"No Default System Reasoning Model available. Please configure " + f"in the Settings -> Model Provider.") + else: + model_config_dict["model"]["provider"] = default_model.model_provider.provider_name + model_config_dict["model"]["name"] = default_model.name model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, @@ -169,10 +182,22 @@ class AppListApi(Resource): app = App(**model_config_template['app']) app_model_config = AppModelConfig(**model_config_template['model_config']) - model_dict = app_model_config.model_dict - model_dict['provider'] = default_model_provider - model_dict['name'] = default_model_name - app_model_config.model = json.dumps(model_dict) + # get model provider + model_provider = ModelProviderFactory.get_preferred_model_provider( + current_user.current_tenant_id, + app_model_config.model_dict["provider"] + ) + + if not model_provider: + if not default_model: + raise ProviderNotInitializeError( + f"No Default System Reasoning Model available. Please configure " + f"in the Settings -> Model Provider.") + else: + model_dict = app_model_config.model_dict + model_dict['provider'] = default_model.model_provider.provider_name + model_dict['name'] = default_model.name + app_model_config.model = json.dumps(model_dict) app.name = args['name'] app.mode = args['mode'] diff --git a/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py b/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py index 3f3384834c..61456f64f4 100644 --- a/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py +++ b/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py @@ -30,8 +30,9 @@ def decrypt_side_effect(tenant_id, encrypted_key): @patch('huggingface_hub.hf_api.ModelInfo') def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker): - mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation') - mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc") + mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation', cardData={'inference': True}) + mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value="abc") + mocker.patch('huggingface_hub.hf_api.HfApi.model_info', return_value=mock_model_info.return_value) MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( model_name='test_model_name', diff --git a/api/tests/unit_tests/model_providers/test_replicate_provider.py b/api/tests/unit_tests/model_providers/test_replicate_provider.py index e555636f0f..69f96368ec 100644 --- a/api/tests/unit_tests/model_providers/test_replicate_provider.py +++ b/api/tests/unit_tests/model_providers/test_replicate_provider.py @@ -23,14 +23,31 @@ def decrypt_side_effect(tenant_id, encrypted_key): return encrypted_key.replace('encrypted_', '') +def version_effect(id: str): + mock_version = MagicMock() + mock_version.openapi_schema = { + 'components': { + 'schemas': { + 'Output': { + 'items': { + 'type': 'string' + } + } + } + } + } + + return mock_version + +@patch('replicate.version.VersionCollection.get', side_effect=version_effect) def test_is_credentials_valid_or_raise_valid(mocker): mock_query = MagicMock() mock_query.return_value = None + mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query) - mocker.patch('replicate.model.Model.versions', return_value=mock_query) MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( - model_name='test_model_name', + model_name='username/test_model_name', model_type=ModelType.TEXT_GENERATION, credentials=VALIDATE_CREDENTIAL.copy() ) diff --git a/api/tests/unit_tests/model_providers/test_tongyi_provider.py b/api/tests/unit_tests/model_providers/test_tongyi_provider.py index 275a1908fe..763f570cfa 100644 --- a/api/tests/unit_tests/model_providers/test_tongyi_provider.py +++ b/api/tests/unit_tests/model_providers/test_tongyi_provider.py @@ -26,7 +26,7 @@ def decrypt_side_effect(tenant_id, encrypted_key): def test_is_provider_credentials_valid_or_raise_valid(mocker): - mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]])) + mocker.patch('core.third_party.langchain.llms.tongyi_llm.EnhanceTongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]])) MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)