From 7b8a10f3ea7efca8efc9e0cf00026ca76e2b1893 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Tue, 5 Dec 2023 16:53:55 +0800 Subject: [PATCH] feat: billing enhancement 20231204 (#1691) Co-authored-by: jyong --- api/config.py | 2 - api/controllers/console/app/app.py | 3 +- .../console/datasets/datasets_document.py | 29 ++----------- .../console/datasets/datasets_segments.py | 6 ++- .../console/explore/installed_app.py | 2 + api/controllers/console/workspace/members.py | 3 +- api/controllers/console/wraps.py | 28 ++++++++++++ .../service_api/dataset/document.py | 6 ++- .../service_api/dataset/segment.py | 5 ++- api/controllers/service_api/wraps.py | 28 ++++++++++++ api/services/billing_service.py | 12 +++--- api/services/dataset_service.py | 43 +------------------ 12 files changed, 86 insertions(+), 81 deletions(-) diff --git a/api/config.py b/api/config.py index 9223b2e2c4..048af509ce 100644 --- a/api/config.py +++ b/api/config.py @@ -54,7 +54,6 @@ DEFAULTS = { 'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100, 'HOSTED_MODERATION_ENABLED': 'False', 'HOSTED_MODERATION_PROVIDERS': '', - 'TENANT_DOCUMENT_COUNT': 100, 'CLEAN_DAY_SETTING': 30, 'UPLOAD_FILE_SIZE_LIMIT': 15, 'UPLOAD_FILE_BATCH_LIMIT': 5, @@ -240,7 +239,6 @@ class Config: self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT') # Dataset Configurations. - self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT') self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING') # File upload Configurations. diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 3f811d7055..ff41c929ac 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -12,7 +12,7 @@ from constants.model_template import model_templates, demo_model_templates from controllers.console import api from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError from core.model_providers.model_factory import ModelFactory from core.model_providers.model_provider_factory import ModelProviderFactory @@ -57,6 +57,7 @@ class AppListApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields) + @cloud_edition_billing_resource_check('apps') def post(self): """Create app""" parser = reqparse.RequestParser() diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 0f5634de4d..58c39d0b46 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -16,7 +16,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu from controllers.console.datasets.error import DocumentAlreadyFinishedError, InvalidActionError, DocumentIndexingError, \ InvalidMetadataError, ArchivedDocumentImmutableError from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.indexing_runner import IndexingRunner from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ LLMBadRequestError @@ -194,6 +194,7 @@ class DatasetDocumentListApi(Resource): @login_required @account_initialization_required @marshal_with(documents_and_batch_fields) + @cloud_edition_billing_resource_check('vector_space') def post(self, dataset_id): dataset_id = str(dataset_id) @@ -252,6 +253,7 @@ class DatasetInitApi(Resource): @login_required @account_initialization_required @marshal_with(dataset_and_document_fields) + @cloud_edition_billing_resource_check('vector_space') def post(self): # The role of the current user in the ta table must be admin or owner if current_user.current_tenant.current_role not in ['admin', 'owner']: @@ -693,6 +695,7 @@ class DocumentStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_resource_check('vector_space') def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) document_id = str(document_id) @@ -770,14 +773,6 @@ class DocumentStatusApi(DocumentResource): if not document.archived: raise InvalidActionError('Document is not archived.') - # check document limit - if current_app.config['EDITION'] == 'CLOUD': - documents_count = DocumentService.get_tenant_documents_count() - total_count = documents_count + 1 - tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT']) - if total_count > tenant_document_count: - raise ValueError(f"All your documents have overed limit {tenant_document_count}.") - document.archived = False document.archived_at = None document.archived_by = None @@ -856,21 +851,6 @@ class DocumentRecoverApi(DocumentResource): return {'result': 'success'}, 204 -class DocumentLimitApi(DocumentResource): - @setup_required - @login_required - @account_initialization_required - def get(self): - """get document limit""" - documents_count = DocumentService.get_tenant_documents_count() - tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT']) - - return { - 'documents_count': documents_count, - 'documents_limit': tenant_document_count - }, 200 - - api.add_resource(GetProcessRuleApi, '/datasets/process-rule') api.add_resource(DatasetDocumentListApi, '/datasets//documents') @@ -896,4 +876,3 @@ api.add_resource(DocumentStatusApi, '/datasets//documents//status/') api.add_resource(DocumentPauseApi, '/datasets//documents//processing/pause') api.add_resource(DocumentRecoverApi, '/datasets//documents//processing/resume') -api.add_resource(DocumentLimitApi, '/datasets/limit') diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index a0715233a4..6051d12999 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -11,7 +11,7 @@ from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_providers.model_factory import ModelFactory from libs.login import login_required @@ -114,6 +114,7 @@ class DatasetDocumentSegmentApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_resource_check('vector_space') def patch(self, dataset_id, segment_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -200,6 +201,7 @@ class DatasetDocumentSegmentAddApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_resource_check('vector_space') def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) @@ -250,6 +252,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_resource_check('vector_space') def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -344,6 +347,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_resource_check('vector_space') def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index d7ee991663..b1e30f4455 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -14,6 +14,7 @@ from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields from models.model import App, InstalledApp, RecommendedApp from services.account_service import TenantService +from controllers.console.wraps import cloud_edition_billing_resource_check class InstalledAppsListApi(Resource): @@ -47,6 +48,7 @@ class InstalledAppsListApi(Resource): @login_required @account_initialization_required + @cloud_edition_billing_resource_check('apps') def post(self): parser = reqparse.RequestParser() parser.add_argument('app_id', type=str, required=True, help='Invalid app_id') diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 00c3e173af..104180e4a6 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse, marshal_with, abort, fields, marsh import services from controllers.console import api from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from libs.helper import TimestampField from extensions.ext_database import db from models.account import Account, TenantAccountJoin @@ -47,6 +47,7 @@ class MemberInviteEmailApi(Resource): @setup_required @login_required @account_initialization_required + @cloud_edition_billing_resource_check('members') def post(self): parser = reqparse.RequestParser() parser.add_argument('emails', type=str, required=True, location='json', action='append') diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 41ce4f200b..efe86ea8c3 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -5,6 +5,7 @@ from flask import current_app, abort from flask_login import current_user from controllers.console.workspace.error import AccountNotInitializedError +from services.billing_service import BillingService def account_initialization_required(view): @@ -41,3 +42,30 @@ def only_edition_self_hosted(view): return view(*args, **kwargs) return decorated + + +def cloud_edition_billing_resource_check(resource: str, + error_msg: str = "You have reached the limit of your subscription."): + def interceptor(view): + @wraps(view) + def decorated(*args, **kwargs): + if current_app.config['EDITION'] == 'CLOUD': + tenant_id = current_user.current_tenant_id + billing_info = BillingService.get_info(tenant_id) + members = billing_info['members'] + apps = billing_info['apps'] + vector_space = billing_info['vector_space'] + + if resource == 'members' and 0 < members['limit'] <= members['size']: + abort(403, error_msg) + elif resource == 'apps' and 0 < apps['limit'] <= apps['size']: + abort(403, error_msg) + elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']: + abort(403, error_msg) + else: + return view(*args, **kwargs) + + return view(*args, **kwargs) + return decorated + return interceptor + diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index e900e84a01..6a4057c1f6 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -11,7 +11,7 @@ from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \ NoFileUploadedError, TooManyFilesError -from controllers.service_api.wraps import DatasetApiResource +from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from libs.login import current_user from core.model_providers.error import ProviderTokenNotInitError from extensions.ext_database import db @@ -24,6 +24,7 @@ from services.file_service import FileService class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" + @cloud_edition_billing_resource_check('vector_space', 'dataset') def post(self, tenant_id, dataset_id): """Create document by text.""" parser = reqparse.RequestParser() @@ -88,6 +89,7 @@ class DocumentAddByTextApi(DatasetApiResource): class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" + @cloud_edition_billing_resource_check('vector_space', 'dataset') def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" parser = reqparse.RequestParser() @@ -147,6 +149,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): class DocumentAddByFileApi(DatasetApiResource): """Resource for documents.""" + @cloud_edition_billing_resource_check('vector_space', 'dataset') def post(self, tenant_id, dataset_id): """Create document by upload file.""" args = {} @@ -212,6 +215,7 @@ class DocumentAddByFileApi(DatasetApiResource): class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" + @cloud_edition_billing_resource_check('vector_space', 'dataset') def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" args = {} diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index ddb4487e6e..2cd6da3d13 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -3,7 +3,7 @@ from flask_restful import reqparse, marshal from werkzeug.exceptions import NotFound from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError -from controllers.service_api.wraps import DatasetApiResource +from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError from core.model_providers.model_factory import ModelFactory from extensions.ext_database import db @@ -14,6 +14,8 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer class SegmentApi(DatasetApiResource): """Resource for segments.""" + + @cloud_edition_billing_resource_check('vector_space', 'dataset') def post(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset @@ -144,6 +146,7 @@ class DatasetSegmentApi(DatasetApiResource): SegmentService.delete_segment(segment, document, dataset) return {'result': 'success'}, 200 + @cloud_edition_billing_resource_check('vector_space', 'dataset') def post(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 4a45af35d6..0b73297dd5 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -11,6 +11,7 @@ from libs.login import _get_user from extensions.ext_database import db from models.account import Tenant, TenantAccountJoin, Account from models.model import ApiToken, App +from services.billing_service import BillingService def validate_app_token(view=None): @@ -40,6 +41,33 @@ def validate_app_token(view=None): return decorator +def cloud_edition_billing_resource_check(resource: str, + api_token_type: str, + error_msg: str = "You have reached the limit of your subscription."): + def interceptor(view): + def decorated(*args, **kwargs): + if current_app.config['EDITION'] == 'CLOUD': + api_token = validate_and_get_api_token(api_token_type) + billing_info = BillingService.get_info(api_token.tenant_id) + + members = billing_info['members'] + apps = billing_info['apps'] + vector_space = billing_info['vector_space'] + + if resource == 'members' and 0 < members['limit'] <= members['size']: + raise Unauthorized(error_msg) + elif resource == 'apps' and 0 < apps['limit'] <= apps['size']: + raise Unauthorized(error_msg) + elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']: + raise Unauthorized(error_msg) + else: + return view(*args, **kwargs) + + return view(*args, **kwargs) + return decorated + return interceptor + + def validate_dataset_token(view=None): def decorator(view): @wraps(view) diff --git a/api/services/billing_service.py b/api/services/billing_service.py index ad8e3dd6e9..2f425a61c8 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,8 +1,6 @@ import os import requests -from services.dataset_service import DatasetService - class BillingService: base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') @@ -14,14 +12,14 @@ class BillingService: billing_info = cls._send_request('GET', '/info', params=params) - vector_size = DatasetService.get_tenant_datasets_usage(tenant_id) - # Convert bytes to MB - billing_info['vector_space']['size'] = int(vector_size / 1024 / 1024) - return billing_info @classmethod - def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''): + def get_subscription(cls, plan: str, + interval: str, + prefilled_email: str = '', + user_name: str = '', + tenant_id: str = ''): params = { 'plan': plan, 'interval': interval, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 47ecca02c7..f6b57321b1 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -227,36 +227,6 @@ class DatasetService: return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ .order_by(db.desc(AppDatasetJoin.created_at)).all() - @staticmethod - def get_tenant_datasets_usage(tenant_id): - # get the high_quality datasets - dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality', - Dataset.tenant_id == tenant_id).all() - if not dataset_ids: - return 0 - dataset_ids = [result[0] for result in dataset_ids] - document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids), - Document.tenant_id == tenant_id, - Document.completed_at.isnot(None), - Document.enabled == True, - Document.archived == False - ).all() - if not document_ids: - return 0 - document_ids = [result[0] for result in document_ids] - document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids), - DocumentSegment.tenant_id == tenant_id, - DocumentSegment.completed_at.isnot(None), - DocumentSegment.enabled == True, - ).all() - if not document_segments: - return 0 - - total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments) - total_vector_size = 1536 * 4 * len(document_segments) - - return total_words_size + total_vector_size - class DocumentService: DEFAULT_RULES = { @@ -480,11 +450,6 @@ class DocumentService: notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] for notion_info in notion_info_list: count = count + len(notion_info['pages']) - documents_count = DocumentService.get_tenant_documents_count() - total_count = documents_count + count - tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT']) - if total_count > tenant_document_count: - raise ValueError(f"over document limit {tenant_document_count}.") # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: dataset.data_source_type = document_data["data_source"]["type"] @@ -770,13 +735,7 @@ class DocumentService: notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] for notion_info in notion_info_list: count = count + len(notion_info['pages']) - # check document limit - if current_app.config['EDITION'] == 'CLOUD': - documents_count = DocumentService.get_tenant_documents_count() - total_count = documents_count + count - tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT']) - if total_count > tenant_document_count: - raise ValueError(f"All your documents have overed limit {tenant_document_count}.") + embedding_model = None dataset_collection_binding_id = None retrieval_model = None