mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 08:18:58 +08:00
feat: optimize model when app create (#875)
This commit is contained in:
parent
cc2d71c253
commit
b7c29ea1b6
38
.github/workflows/api-unit-tests.yml
vendored
Normal file
38
.github/workflows/api-unit-tests.yml
vendored
Normal file
@ -0,0 +1,38 @@
|
||||
name: Run Pytest
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- deploy/dev
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
|
||||
restore-keys: ${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest
|
||||
pip install -r api/requirements.txt
|
||||
|
||||
- name: Run pytest
|
||||
run: pytest api/tests/unit_tests
|
@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import login_required, current_user
|
||||
@ -11,7 +12,9 @@ from controllers.console import api
|
||||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from libs.helper import TimestampField
|
||||
@ -124,24 +127,34 @@ class AppListApi(Resource):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
default_model = ModelFactory.get_default_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.TEXT_GENERATION
|
||||
try:
|
||||
default_model = ModelFactory.get_text_generation_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
|
||||
if default_model:
|
||||
default_model_provider = default_model.provider_name
|
||||
default_model_name = default_model.model_name
|
||||
else:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Text Generation Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||
default_model = None
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
default_model = None
|
||||
|
||||
if args['model_config'] is not None:
|
||||
# validate config
|
||||
model_config_dict = args['model_config']
|
||||
model_config_dict["model"]["provider"] = default_model_provider
|
||||
model_config_dict["model"]["name"] = default_model_name
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||
current_user.current_tenant_id,
|
||||
model_config_dict["model"]["provider"]
|
||||
)
|
||||
|
||||
if not model_provider:
|
||||
if not default_model:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Default System Reasoning Model available. Please configure "
|
||||
f"in the Settings -> Model Provider.")
|
||||
else:
|
||||
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
|
||||
model_config_dict["model"]["name"] = default_model.name
|
||||
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
@ -169,9 +182,21 @@ class AppListApi(Resource):
|
||||
app = App(**model_config_template['app'])
|
||||
app_model_config = AppModelConfig(**model_config_template['model_config'])
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(
|
||||
current_user.current_tenant_id,
|
||||
app_model_config.model_dict["provider"]
|
||||
)
|
||||
|
||||
if not model_provider:
|
||||
if not default_model:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Default System Reasoning Model available. Please configure "
|
||||
f"in the Settings -> Model Provider.")
|
||||
else:
|
||||
model_dict = app_model_config.model_dict
|
||||
model_dict['provider'] = default_model_provider
|
||||
model_dict['name'] = default_model_name
|
||||
model_dict['provider'] = default_model.model_provider.provider_name
|
||||
model_dict['name'] = default_model.name
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
app.name = args['name']
|
||||
|
@ -30,8 +30,9 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
|
||||
@patch('huggingface_hub.hf_api.ModelInfo')
|
||||
def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker):
|
||||
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
|
||||
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
|
||||
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation', cardData={'inference': True})
|
||||
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value="abc")
|
||||
mocker.patch('huggingface_hub.hf_api.HfApi.model_info', return_value=mock_model_info.return_value)
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
|
@ -23,14 +23,31 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def version_effect(id: str):
|
||||
mock_version = MagicMock()
|
||||
mock_version.openapi_schema = {
|
||||
'components': {
|
||||
'schemas': {
|
||||
'Output': {
|
||||
'items': {
|
||||
'type': 'string'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mock_version
|
||||
|
||||
@patch('replicate.version.VersionCollection.get', side_effect=version_effect)
|
||||
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||
mock_query = MagicMock()
|
||||
mock_query.return_value = None
|
||||
|
||||
mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
|
||||
mocker.patch('replicate.model.Model.versions', return_value=mock_query)
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||
model_name='test_model_name',
|
||||
model_name='username/test_model_name',
|
||||
model_type=ModelType.TEXT_GENERATION,
|
||||
credentials=VALIDATE_CREDENTIAL.copy()
|
||||
)
|
||||
|
@ -26,7 +26,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
|
||||
mocker.patch('core.third_party.langchain.llms.tongyi_llm.EnhanceTongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user