mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 08:18:58 +08:00
feat: optimize model when app create (#875)
This commit is contained in:
parent
cc2d71c253
commit
b7c29ea1b6
38
.github/workflows/api-unit-tests.yml
vendored
Normal file
38
.github/workflows/api-unit-tests.yml
vendored
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
name: Run Pytest
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- deploy/dev
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Cache pip dependencies
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pip
|
||||||
|
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
|
||||||
|
restore-keys: ${{ runner.os }}-pip-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install pytest
|
||||||
|
pip install -r api/requirements.txt
|
||||||
|
|
||||||
|
- name: Run pytest
|
||||||
|
run: pytest api/tests/unit_tests
|
@ -1,5 +1,6 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
@ -11,7 +12,9 @@ from controllers.console import api
|
|||||||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
|
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||||
from core.model_providers.model_factory import ModelFactory
|
from core.model_providers.model_factory import ModelFactory
|
||||||
|
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
from core.model_providers.models.entity.model_params import ModelType
|
from core.model_providers.models.entity.model_params import ModelType
|
||||||
from events.app_event import app_was_created, app_was_deleted
|
from events.app_event import app_was_created, app_was_deleted
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
@ -124,24 +127,34 @@ class AppListApi(Resource):
|
|||||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
default_model = ModelFactory.get_default_model(
|
try:
|
||||||
tenant_id=current_user.current_tenant_id,
|
default_model = ModelFactory.get_text_generation_model(
|
||||||
model_type=ModelType.TEXT_GENERATION
|
tenant_id=current_user.current_tenant_id
|
||||||
)
|
)
|
||||||
|
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||||
if default_model:
|
default_model = None
|
||||||
default_model_provider = default_model.provider_name
|
except Exception as e:
|
||||||
default_model_name = default_model.model_name
|
logging.exception(e)
|
||||||
else:
|
default_model = None
|
||||||
raise ProviderNotInitializeError(
|
|
||||||
f"No Text Generation Model available. Please configure a valid provider "
|
|
||||||
f"in the Settings -> Model Provider.")
|
|
||||||
|
|
||||||
if args['model_config'] is not None:
|
if args['model_config'] is not None:
|
||||||
# validate config
|
# validate config
|
||||||
model_config_dict = args['model_config']
|
model_config_dict = args['model_config']
|
||||||
model_config_dict["model"]["provider"] = default_model_provider
|
|
||||||
model_config_dict["model"]["name"] = default_model_name
|
# get model provider
|
||||||
|
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||||
|
current_user.current_tenant_id,
|
||||||
|
model_config_dict["model"]["provider"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_provider:
|
||||||
|
if not default_model:
|
||||||
|
raise ProviderNotInitializeError(
|
||||||
|
f"No Default System Reasoning Model available. Please configure "
|
||||||
|
f"in the Settings -> Model Provider.")
|
||||||
|
else:
|
||||||
|
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
|
||||||
|
model_config_dict["model"]["name"] = default_model.name
|
||||||
|
|
||||||
model_configuration = AppModelConfigService.validate_configuration(
|
model_configuration = AppModelConfigService.validate_configuration(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
@ -169,10 +182,22 @@ class AppListApi(Resource):
|
|||||||
app = App(**model_config_template['app'])
|
app = App(**model_config_template['app'])
|
||||||
app_model_config = AppModelConfig(**model_config_template['model_config'])
|
app_model_config = AppModelConfig(**model_config_template['model_config'])
|
||||||
|
|
||||||
model_dict = app_model_config.model_dict
|
# get model provider
|
||||||
model_dict['provider'] = default_model_provider
|
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||||
model_dict['name'] = default_model_name
|
current_user.current_tenant_id,
|
||||||
app_model_config.model = json.dumps(model_dict)
|
app_model_config.model_dict["provider"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_provider:
|
||||||
|
if not default_model:
|
||||||
|
raise ProviderNotInitializeError(
|
||||||
|
f"No Default System Reasoning Model available. Please configure "
|
||||||
|
f"in the Settings -> Model Provider.")
|
||||||
|
else:
|
||||||
|
model_dict = app_model_config.model_dict
|
||||||
|
model_dict['provider'] = default_model.model_provider.provider_name
|
||||||
|
model_dict['name'] = default_model.name
|
||||||
|
app_model_config.model = json.dumps(model_dict)
|
||||||
|
|
||||||
app.name = args['name']
|
app.name = args['name']
|
||||||
app.mode = args['mode']
|
app.mode = args['mode']
|
||||||
|
@ -30,8 +30,9 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
|||||||
|
|
||||||
@patch('huggingface_hub.hf_api.ModelInfo')
|
@patch('huggingface_hub.hf_api.ModelInfo')
|
||||||
def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker):
|
def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker):
|
||||||
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
|
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation', cardData={'inference': True})
|
||||||
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
|
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value="abc")
|
||||||
|
mocker.patch('huggingface_hub.hf_api.HfApi.model_info', return_value=mock_model_info.return_value)
|
||||||
|
|
||||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||||
model_name='test_model_name',
|
model_name='test_model_name',
|
||||||
|
@ -23,14 +23,31 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
|||||||
return encrypted_key.replace('encrypted_', '')
|
return encrypted_key.replace('encrypted_', '')
|
||||||
|
|
||||||
|
|
||||||
|
def version_effect(id: str):
|
||||||
|
mock_version = MagicMock()
|
||||||
|
mock_version.openapi_schema = {
|
||||||
|
'components': {
|
||||||
|
'schemas': {
|
||||||
|
'Output': {
|
||||||
|
'items': {
|
||||||
|
'type': 'string'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mock_version
|
||||||
|
|
||||||
|
@patch('replicate.version.VersionCollection.get', side_effect=version_effect)
|
||||||
def test_is_credentials_valid_or_raise_valid(mocker):
|
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||||
mock_query = MagicMock()
|
mock_query = MagicMock()
|
||||||
mock_query.return_value = None
|
mock_query.return_value = None
|
||||||
|
|
||||||
mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
|
mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
|
||||||
mocker.patch('replicate.model.Model.versions', return_value=mock_query)
|
|
||||||
|
|
||||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||||
model_name='test_model_name',
|
model_name='username/test_model_name',
|
||||||
model_type=ModelType.TEXT_GENERATION,
|
model_type=ModelType.TEXT_GENERATION,
|
||||||
credentials=VALIDATE_CREDENTIAL.copy()
|
credentials=VALIDATE_CREDENTIAL.copy()
|
||||||
)
|
)
|
||||||
|
@ -26,7 +26,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
|||||||
|
|
||||||
|
|
||||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||||
mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
|
mocker.patch('core.third_party.langchain.llms.tongyi_llm.EnhanceTongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
|
||||||
|
|
||||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user