diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 1395963f1d..7b58120a58 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -12,7 +12,11 @@ from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_knowledge_limit_check, + cloud_edition_billing_resource_check, +) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -207,6 +211,7 @@ class DatasetDocumentSegmentAddApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_knowledge_limit_check('add_segment') def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) @@ -357,6 +362,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_knowledge_limit_check('add_segment') def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 84f9918470..7c8ad11078 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -51,14 +51,12 @@ def cloud_edition_billing_resource_check(resource: str, @wraps(view) def decorated(*args, **kwargs): features = FeatureService.get_features(current_user.current_tenant_id) - if features.billing.enabled: members = features.members apps = features.apps vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota annotation_quota_limit = features.annotation_quota_limit - if resource == 'members' and 0 < members.limit <= members.size: abort(403, error_msg) elif resource == 'apps' and 0 < apps.limit <= apps.size: @@ -80,7 +78,29 @@ def cloud_edition_billing_resource_check(resource: str, return view(*args, **kwargs) return view(*args, **kwargs) + return decorated + + return interceptor + + +def cloud_edition_billing_knowledge_limit_check(resource: str, + error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): + def interceptor(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_features(current_user.current_tenant_id) + if features.billing.enabled: + if resource == 'add_segment': + if features.billing.subscription.plan == 'sandbox': + abort(403, error_msg) + else: + return view(*args, **kwargs) + + return view(*args, **kwargs) + + return decorated + return interceptor @@ -99,4 +119,5 @@ def cloud_utm_record(view): except Exception as e: pass return view(*args, **kwargs) + return decorated diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 5d3a081357..0849eb72ba 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -4,7 +4,11 @@ from werkzeug.exceptions import NotFound from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError -from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check +from controllers.service_api.wraps import ( + DatasetApiResource, + cloud_edition_billing_knowledge_limit_check, + cloud_edition_billing_resource_check, +) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -18,6 +22,7 @@ class SegmentApi(DatasetApiResource): """Resource for segments.""" @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset') def post(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index bdcbaecbea..a75583469e 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -8,7 +8,7 @@ from flask import current_app, request from flask_login import user_logged_in from flask_restful import Resource from pydantic import BaseModel -from werkzeug.exceptions import NotFound, Unauthorized +from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from extensions.ext_database import db from libs.login import _get_user @@ -92,13 +92,13 @@ def cloud_edition_billing_resource_check(resource: str, documents_upload_quota = features.documents_upload_quota if resource == 'members' and 0 < members.limit <= members.size: - raise Unauthorized(error_msg) + raise Forbidden(error_msg) elif resource == 'apps' and 0 < apps.limit <= apps.size: - raise Unauthorized(error_msg) + raise Forbidden(error_msg) elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: - raise Unauthorized(error_msg) + raise Forbidden(error_msg) elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: - raise Unauthorized(error_msg) + raise Forbidden(error_msg) else: return view(*args, **kwargs) @@ -107,6 +107,27 @@ def cloud_edition_billing_resource_check(resource: str, return interceptor +def cloud_edition_billing_knowledge_limit_check(resource: str, + api_token_type: str, + error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): + def interceptor(view): + @wraps(view) + def decorated(*args, **kwargs): + api_token = validate_and_get_api_token(api_token_type) + features = FeatureService.get_features(api_token.tenant_id) + if features.billing.enabled: + if resource == 'add_segment': + if features.billing.subscription.plan == 'sandbox': + raise Forbidden(error_msg) + else: + return view(*args, **kwargs) + + return view(*args, **kwargs) + + return decorated + + return interceptor + def validate_dataset_token(view=None): def decorator(view): @wraps(view)