diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index f0a884cbec..717edf07a6 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -3,7 +3,7 @@ import random
 from datetime import datetime
 from typing import List
 
-from flask import request
+from flask import request, current_app
 from flask_login import current_user
 from core.login.login import login_required
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
@@ -275,7 +275,8 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument('duplicate', type=bool, nullable=False, location='json')
         parser.add_argument('original_document_id', type=str, required=False, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
 
         if not dataset.indexing_technique and not args['indexing_technique']:
@@ -335,7 +336,8 @@ class DatasetInitApi(Resource):
         parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
 
         try:
@@ -483,7 +485,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         indexing_runner = IndexingRunner()
         try:
             response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                               data_process_rule_dict, None, dataset_id)
+                                                              data_process_rule_dict, None, dataset_id)
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 f"No Embedding Model available. Please configure a valid provider "
@@ -855,6 +857,14 @@ class DocumentStatusApi(DocumentResource):
             if not document.archived:
                 raise InvalidActionError('Document is not archived.')
 
+            # check document limit
+            if current_app.config['EDITION'] == 'CLOUD':
+                documents_count = DocumentService.get_tenant_documents_count()
+                total_count = documents_count + 1
+                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+                if total_count > tenant_document_count:
+                    raise ValueError(f"You have reached the document limit of {tenant_document_count}.")
+
             document.archived = False
             document.archived_at = None
             document.archived_by = None
@@ -872,6 +882,10 @@
 
 
 class DocumentPauseApi(DocumentResource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
     def patch(self, dataset_id, document_id):
         """pause document."""
         dataset_id = str(dataset_id)
@@ -901,6 +915,9 @@
 
 
 class DocumentRecoverApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
     def patch(self, dataset_id, document_id):
         """recover document."""
         dataset_id = str(dataset_id)
@@ -926,6 +943,21 @@
 
         return {'result': 'success'}, 204
 
+class DocumentLimitApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        """get document limit"""
+        documents_count = DocumentService.get_tenant_documents_count()
+        tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+
+        return {
+            'documents_count': documents_count,
+            'documents_limit': tenant_document_count
+        }, 200
+
+
 api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
 api.add_resource(DatasetDocumentListApi,
                  '/datasets/<uuid:dataset_id>/documents')
@@ -951,3 +983,4 @@
 api.add_resource(DocumentStatusApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
 api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
 api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
+api.add_resource(DocumentLimitApi, '/datasets/limit')
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index bcaa3e1bc1..892ec52d63 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -394,11 +394,20 @@ class DocumentService:
     def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                       account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                       created_from: str = 'web'):
+        # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
+            count = 0
+            if document_data["data_source"]["type"] == "upload_file":
+                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+                count = len(upload_file_list)
+            elif document_data["data_source"]["type"] == "notion_import":
+                notion_page_list = document_data["data_source"]['info_list']['notion_info_list']['pages']
+                count = len(notion_page_list)
             documents_count = DocumentService.get_tenant_documents_count()
+            total_count = documents_count + count
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if documents_count > tenant_document_count:
+            if total_count > tenant_document_count:
                 raise ValueError(f"over document limit {tenant_document_count}.")
 
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
@@ -649,12 +658,20 @@ class DocumentService:
     @staticmethod
     def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
+        count = 0
+        if document_data["data_source"]["type"] == "upload_file":
+            upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+            count = len(upload_file_list)
+        elif document_data["data_source"]["type"] == "notion_import":
+            notion_page_list = document_data["data_source"]['info_list']['notion_info_list']['pages']
+            count = len(notion_page_list)
         # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
             documents_count = DocumentService.get_tenant_documents_count()
+            total_count = documents_count + count
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if documents_count > tenant_document_count:
-                raise ValueError(f"over document limit {tenant_document_count}.")
+            if total_count > tenant_document_count:
+                raise ValueError(f"You have reached the document limit of {tenant_document_count}.")
 
         embedding_model = ModelFactory.get_embedding_model(
             tenant_id=tenant_id
         )
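
A quick way to exercise the new DocumentLimitApi route wired up above, as a minimal usage sketch in Python. The console API prefix (/console/api), the dev-server port 5001, and the session cookie name are assumptions about a typical deployment, not anything this patch establishes:

    # Hypothetical sketch: query the new /datasets/limit route and report
    # how much of the tenant's document quota is used. Base URL, port, and
    # cookie name below are assumptions, not part of the patch.
    import requests

    resp = requests.get(
        "http://localhost:5001/console/api/datasets/limit",
        cookies={"session": "<your-console-session-cookie>"},
    )
    resp.raise_for_status()
    data = resp.json()
    print(f"{data['documents_count']} of {data['documents_limit']} documents used")

The response mirrors the dict returned by DocumentLimitApi.get, so a client can, for example, disable its upload control once documents_count reaches documents_limit.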