From db7156dafda199d61cf2c41e0776a87d08cb2558 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 18 Aug 2023 17:37:31 +0800 Subject: [PATCH] Feature/mutil embedding model (#908) Co-authored-by: JzoNg Co-authored-by: jyong Co-authored-by: StyleZhang --- api/controllers/console/datasets/datasets.py | 36 +++- .../console/datasets/datasets_document.py | 43 +++- .../console/datasets/datasets_segments.py | 191 ++++++++++++++++-- .../console/datasets/hit_testing.py | 7 +- api/core/docstore/dataset_docstore.py | 16 +- api/core/generator/llm_generator.py | 4 +- api/core/index/index.py | 4 +- api/core/indexing_runner.py | 113 +++++++---- api/core/prompt/prompts.py | 4 +- api/core/tool/dataset_retriever_tool.py | 14 +- .../2c8af9671032_add_qa_document_language.py | 32 +++ .../e8883b0148c9_add_dataset_model_name.py | 34 ++++ api/models/dataset.py | 5 + api/requirements.txt | 3 +- api/services/dataset_service.py | 163 +++++++++------ api/services/hit_testing_service.py | 4 +- api/services/vector_service.py | 69 +++++++ .../batch_create_segment_to_index_task.py | 95 +++++++++ api/tasks/delete_segment_from_index_task.py | 58 ++++++ ....py => disable_segment_from_index_task.py} | 8 +- .../update_segment_keyword_index_task.py | 11 - .../(commonLayout)/datasets/DatasetCard.tsx | 31 ++- web/app/(commonLayout)/datasets/page.tsx | 6 - web/app/(commonLayout)/list.module.css | 8 + .../dataset-config/card-item/index.tsx | 19 +- .../dataset-config/select-dataset/index.tsx | 21 +- .../select-dataset/style.module.css | 6 +- .../vender/line/general/dots-horizontal.svg | 8 +- .../vender/line/general/DotsHorizontal.json | 14 +- web/app/components/base/popover/index.tsx | 21 +- .../datasets/create/file-uploader/index.tsx | 2 +- .../datasets/create/step-two/index.tsx | 63 ++++-- .../create/step-two/language-select/index.tsx | 38 ++++ .../detail/batch-modal/csv-downloader.tsx | 108 ++++++++++ .../detail/batch-modal/csv-uploader.tsx | 126 ++++++++++++ .../documents/detail/batch-modal/index.tsx | 65 ++++++ .../detail/completed/InfiniteVirtualList.tsx | 7 + .../detail/completed/SegmentCard.tsx | 44 +++- .../documents/detail/completed/index.tsx | 93 ++++++--- .../detail/completed/style.module.css | 21 ++ .../datasets/documents/detail/index.tsx | 81 ++++++-- .../documents/detail/segment-add/index.tsx | 84 ++++++++ .../components/datasets/documents/list.tsx | 78 ++++++- .../datasets/settings/form/index.tsx | 37 +++- web/i18n/lang/dataset-creation.en.ts | 1 + web/i18n/lang/dataset-creation.zh.ts | 1 + web/i18n/lang/dataset-documents.en.ts | 24 ++- web/i18n/lang/dataset-documents.zh.ts | 22 ++ web/i18n/lang/dataset-settings.en.ts | 3 + web/i18n/lang/dataset-settings.zh.ts | 3 + web/i18n/lang/dataset.en.ts | 2 + web/i18n/lang/dataset.zh.ts | 2 + web/models/datasets.ts | 9 + web/service/datasets.ts | 20 +- 54 files changed, 1704 insertions(+), 278 deletions(-) create mode 100644 api/migrations/versions/2c8af9671032_add_qa_document_language.py create mode 100644 api/migrations/versions/e8883b0148c9_add_dataset_model_name.py create mode 100644 api/services/vector_service.py create mode 100644 api/tasks/batch_create_segment_to_index_task.py create mode 100644 api/tasks/delete_segment_from_index_task.py rename api/tasks/{remove_segment_from_index_task.py => disable_segment_from_index_task.py} (89%) create mode 100644 web/app/components/datasets/create/step-two/language-select/index.tsx create mode 100644 web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx create mode 
100644 web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx create mode 100644 web/app/components/datasets/documents/detail/batch-modal/index.tsx create mode 100644 web/app/components/datasets/documents/detail/segment-add/index.tsx diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index dfe9026eb4..a2bdf28356 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -10,13 +10,15 @@ from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.indexing_runner import IndexingRunner -from core.model_providers.error import LLMBadRequestError +from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_providers.model_factory import ModelFactory +from core.model_providers.models.entity.model_params import ModelType from libs.helper import TimestampField from extensions.ext_database import db from models.dataset import DocumentSegment, Document from models.model import UploadFile from services.dataset_service import DatasetService, DocumentService +from services.provider_service import ProviderService dataset_detail_fields = { 'id': fields.String, @@ -33,6 +35,9 @@ dataset_detail_fields = { 'created_at': TimestampField, 'updated_by': fields.String, 'updated_at': TimestampField, + 'embedding_model': fields.String, + 'embedding_model_provider': fields.String, + 'embedding_available': fields.Boolean } dataset_query_detail_fields = { @@ -74,8 +79,22 @@ class DatasetListApi(Resource): datasets, total = DatasetService.get_datasets(page, limit, provider, current_user.current_tenant_id, current_user) + # check embedding setting + provider_service = ProviderService() + valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value) + # if len(valid_model_list) == 0: + # raise ProviderNotInitializeError( + # f"No Embedding Model available. 
Please configure a valid provider " + # f"in the Settings -> Model Provider.") + model_names = [item['model_name'] for item in valid_model_list] + data = marshal(datasets, dataset_detail_fields) + for item in data: + if item['embedding_model'] in model_names: + item['embedding_available'] = True + else: + item['embedding_available'] = False response = { - 'data': marshal(datasets, dataset_detail_fields), + 'data': data, 'has_more': len(datasets) == limit, 'limit': limit, 'total': total, @@ -99,7 +118,6 @@ class DatasetListApi(Resource): # The role of the current user in the ta table must be admin or owner if current_user.current_tenant.current_role not in ['admin', 'owner']: raise Forbidden() - try: ModelFactory.get_embedding_model( tenant_id=current_user.current_tenant_id @@ -233,6 +251,8 @@ class DatasetIndexingEstimateApi(Resource): parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') + parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') + parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) @@ -250,11 +270,14 @@ class DatasetIndexingEstimateApi(Resource): try: response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, - args['process_rule'], args['doc_form']) + args['process_rule'], args['doc_form'], + args['doc_language'], args['dataset_id']) except LLMBadRequestError: raise ProviderNotInitializeError( f"No Embedding Model available. Please configure a valid provider " f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) elif args['info_list']['data_source_type'] == 'notion_import': indexing_runner = IndexingRunner() @@ -262,11 +285,14 @@ class DatasetIndexingEstimateApi(Resource): try: response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['info_list']['notion_info_list'], - args['process_rule'], args['doc_form']) + args['process_rule'], args['doc_form'], + args['doc_language'], args['dataset_id']) except LLMBadRequestError: raise ProviderNotInitializeError( f"No Embedding Model available. 
Please configure a valid provider " f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) else: raise ValueError('Data source type not support') return response, 200 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index a1ef7b767c..5d67e5bfe0 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -274,6 +274,7 @@ class DatasetDocumentListApi(Resource): parser.add_argument('duplicate', type=bool, nullable=False, location='json') parser.add_argument('original_document_id', type=str, required=False, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') + parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') args = parser.parse_args() if not dataset.indexing_technique and not args['indexing_technique']: @@ -282,14 +283,19 @@ class DatasetDocumentListApi(Resource): # validate args DocumentService.document_create_args_validate(args) + # check embedding model setting try: ModelFactory.get_embedding_model( - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model ) except LLMBadRequestError: raise ProviderNotInitializeError( f"No Embedding Model available. Please configure a valid provider " f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) try: documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) @@ -328,6 +334,7 @@ class DatasetInitApi(Resource): parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') + parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') args = parser.parse_args() try: @@ -406,11 +413,13 @@ class DocumentIndexingEstimateApi(DocumentResource): try: response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file], - data_process_rule_dict) + data_process_rule_dict, None, dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( f"No Embedding Model available. Please configure a valid provider " f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) return response @@ -473,22 +482,27 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): indexing_runner = IndexingRunner() try: response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, - data_process_rule_dict) + data_process_rule_dict, None, dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( f"No Embedding Model available. 
Please configure a valid provider " f"in the Settings -> Model Provider.") - elif dataset.data_source_type: + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + elif dataset.data_source_type == 'notion_import': indexing_runner = IndexingRunner() try: response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, info_list, - data_process_rule_dict) + data_process_rule_dict, + None, dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( f"No Embedding Model available. Please configure a valid provider " f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) else: raise ValueError('Data source type not support') return response @@ -575,7 +589,8 @@ class DocumentIndexingStatusApi(DocumentResource): document.completed_segments = completed_segments document.total_segments = total_segments - + if document.is_paused: + document.indexing_status = 'paused' return marshal(document, self.document_status_fields) @@ -832,6 +847,22 @@ class DocumentStatusApi(DocumentResource): remove_document_from_index_task.delay(document_id) + return {'result': 'success'}, 200 + elif action == "un_archive": + if not document.archived: + raise InvalidActionError('Document is not archived.') + + document.archived = False + document.archived_at = None + document.archived_by = None + document.updated_at = datetime.utcnow() + db.session.commit() + + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) + + add_document_to_index_task.delay(document_id) + return {'result': 'success'}, 200 else: raise InvalidActionError() diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index de6d031dd7..7ad4c8a2d6 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,15 +1,20 @@ # -*- coding:utf-8 -*- +import uuid from datetime import datetime +from flask import request from flask_login import login_required, current_user from flask_restful import Resource, reqparse, fields, marshal from werkzeug.exceptions import NotFound, Forbidden import services from controllers.console import api -from controllers.console.datasets.error import InvalidActionError +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError +from core.model_providers.model_factory import ModelFactory from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -17,7 +22,9 @@ from models.dataset import DocumentSegment from libs.helper import TimestampField from services.dataset_service import DatasetService, DocumentService, SegmentService from tasks.enable_segment_to_index_task import enable_segment_to_index_task -from tasks.remove_segment_from_index_task import remove_segment_from_index_task +from tasks.disable_segment_from_index_task import disable_segment_from_index_task +from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task +import pandas as pd segment_fields = { 'id': fields.String, 
@@ -152,6 +159,20 @@ class DatasetDocumentSegmentApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) + # check embedding model setting + try: + ModelFactory.get_embedding_model( + tenant_id=current_user.current_tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + except LLMBadRequestError: + raise ProviderNotInitializeError( + f"No Embedding Model available. Please configure a valid provider " + f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + segment = DocumentSegment.query.filter( DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id @@ -197,7 +218,7 @@ class DatasetDocumentSegmentApi(Resource): # Set cache to prevent indexing the same segment multiple times redis_client.setex(indexing_cache_key, 600, 1) - remove_segment_from_index_task.delay(segment.id) + disable_segment_from_index_task.delay(segment.id) return {'result': 'success'}, 200 else: @@ -222,6 +243,19 @@ class DatasetDocumentSegmentAddApi(Resource): # The role of the current user in the ta table must be admin or owner if current_user.current_tenant.current_role not in ['admin', 'owner']: raise Forbidden() + # check embedding model setting + try: + ModelFactory.get_embedding_model( + tenant_id=current_user.current_tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + except LLMBadRequestError: + raise ProviderNotInitializeError( + f"No Embedding Model available. Please configure a valid provider " + f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: @@ -233,7 +267,7 @@ class DatasetDocumentSegmentAddApi(Resource): parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.create_segment(args, document) + segment = SegmentService.create_segment(args, document, dataset) return { 'data': marshal(segment, segment_fields), 'doc_form': document.doc_form @@ -245,6 +279,61 @@ class DatasetDocumentSegmentUpdateApi(Resource): @login_required @account_initialization_required def patch(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound('Document not found.') + # check embedding model setting + try: + ModelFactory.get_embedding_model( + tenant_id=current_user.current_tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + except LLMBadRequestError: + raise ProviderNotInitializeError( + f"No Embedding Model available. 
Please configure a valid provider " + f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), + DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound('Segment not found.') + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument('content', type=str, required=True, nullable=False, location='json') + parser.add_argument('answer', type=str, required=False, nullable=True, location='json') + parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + args = parser.parse_args() + SegmentService.segment_create_args_validate(args, document) + segment = SegmentService.update_segment(args, segment, document, dataset) + return { + 'data': marshal(segment, segment_fields), + 'doc_form': document.doc_form + }, 200 + + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -270,17 +359,88 @@ class DatasetDocumentSegmentUpdateApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - # validate args - parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') - args = parser.parse_args() - SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(args, segment, document) + SegmentService.delete_segment(segment, document, dataset) + return {'result': 'success'}, 200 + + +class DatasetDocumentSegmentBatchImportApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, dataset_id, document_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound('Document not found.') + try: + ModelFactory.get_embedding_model( + tenant_id=current_user.current_tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + except LLMBadRequestError: + raise ProviderNotInitializeError( + f"No Embedding Model available. 
Please configure a valid provider " + f"in the Settings -> Model Provider.") + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + # get file from request + file = request.files['file'] + # check file + if 'file' not in request.files: + raise NoFileUploadedError() + + if len(request.files) > 1: + raise TooManyFilesError() + # check file type + if not file.filename.endswith('.csv'): + raise ValueError("Invalid file type. Only CSV files are allowed") + + try: + # Skip the first row + df = pd.read_csv(file) + result = [] + for index, row in df.iterrows(): + if document.doc_form == 'qa_model': + data = {'content': row[0], 'answer': row[1]} + else: + data = {'content': row[0]} + result.append(data) + if len(result) == 0: + raise ValueError("The CSV file is empty.") + # async job + job_id = str(uuid.uuid4()) + indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id)) + # send batch add segments task + redis_client.setnx(indexing_cache_key, 'waiting') + batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id, + current_user.current_tenant_id, current_user.id) + except Exception as e: + return {'error': str(e)}, 500 return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form + 'job_id': job_id, + 'job_status': 'waiting' + }, 200 + + @setup_required + @login_required + @account_initialization_required + def get(self, job_id): + job_id = str(job_id) + indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is None: + raise ValueError("The job is not exist.") + + return { + 'job_id': job_id, + 'job_status': cache_result.decode() }, 200 @@ -292,3 +452,6 @@ api.add_resource(DatasetDocumentSegmentAddApi, '/datasets//documents//segment') api.add_resource(DatasetDocumentSegmentUpdateApi, '/datasets//documents//segments/') +api.add_resource(DatasetDocumentSegmentBatchImportApi, + '/datasets//documents//segments/batch_import', + '/datasets/batch_import_status/') diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 399bd4c0c9..00d14d93dc 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -11,7 +11,8 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ + LLMBadRequestError from libs.helper import TimestampField from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService @@ -102,6 +103,10 @@ class HitTestingApi(Resource): raise ProviderQuotaExceededError() except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() + except LLMBadRequestError: + raise ProviderNotInitializeError( + f"No Embedding Model available. 
Please configure a valid provider " + f"in the Settings -> Model Provider.") except ValueError as e: raise ValueError(str(e)) except Exception as e: diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 786ae4469d..24dc6194e2 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -10,10 +10,10 @@ from models.dataset import Dataset, DocumentSegment class DatesetDocumentStore: def __init__( - self, - dataset: Dataset, - user_id: str, - document_id: Optional[str] = None, + self, + dataset: Dataset, + user_id: str, + document_id: Optional[str] = None, ): self._dataset = dataset self._user_id = user_id @@ -59,7 +59,7 @@ class DatesetDocumentStore: return output def add_documents( - self, docs: Sequence[Document], allow_update: bool = True + self, docs: Sequence[Document], allow_update: bool = True ) -> None: max_position = db.session.query(func.max(DocumentSegment.position)).filter( DocumentSegment.document_id == self._document_id @@ -69,7 +69,9 @@ class DatesetDocumentStore: max_position = 0 embedding_model = ModelFactory.get_embedding_model( - tenant_id=self._dataset.tenant_id + tenant_id=self._dataset.tenant_id, + model_provider_name=self._dataset.embedding_model_provider, + model_name=self._dataset.embedding_model ) for doc in docs: @@ -123,7 +125,7 @@ class DatesetDocumentStore: return result is not None def get_document( - self, doc_id: str, raise_error: bool = True + self, doc_id: str, raise_error: bool = True ) -> Optional[Document]: document_segment = self.get_document_segment(doc_id) diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index 034483ebd5..3ace061ac7 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -179,8 +179,8 @@ class LLMGenerator: return rule_config @classmethod - def generate_qa_document(cls, tenant_id: str, query): - prompt = GENERATOR_QA_PROMPT + def generate_qa_document(cls, tenant_id: str, query, document_language: str): + prompt = GENERATOR_QA_PROMPT.format(language=document_language) model_instance = ModelFactory.get_text_generation_model( tenant_id=tenant_id, diff --git a/api/core/index/index.py b/api/core/index/index.py index 316b604566..26b6a84dfe 100644 --- a/api/core/index/index.py +++ b/api/core/index/index.py @@ -15,7 +15,9 @@ class IndexBuilder: return None embedding_model = ModelFactory.get_embedding_model( - tenant_id=dataset.tenant_id + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model ) embeddings = CacheEmbedding(embedding_model) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f23ed0f2b5..21a678fef9 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -67,14 +67,6 @@ class IndexingRunner: dataset_document=dataset_document, processing_rule=processing_rule ) - # new_documents = [] - # for document in documents: - # response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content) - # document_qa_list = self.format_split_text(response) - # for result in document_qa_list: - # document = Document(page_content=result['question'], metadata={'source': result['answer']}) - # new_documents.append(document) - # build index self._build_index( dataset=dataset, dataset_document=dataset_document, @@ -225,14 +217,25 @@ class IndexingRunner: db.session.commit() def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], 
tmp_processing_rule: dict, - doc_form: str = None) -> dict: + doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict: """ Estimate the indexing for the document. """ - embedding_model = ModelFactory.get_embedding_model( - tenant_id=tenant_id - ) - + if dataset_id: + dataset = Dataset.query.filter_by( + id=dataset_id + ).first() + if not dataset: + raise ValueError('Dataset not found.') + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + else: + embedding_model = ModelFactory.get_embedding_model( + tenant_id=tenant_id + ) tokens = 0 preview_texts = [] total_segments = 0 @@ -263,14 +266,13 @@ class IndexingRunner: tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content)) - text_generation_model = ModelFactory.get_text_generation_model( - tenant_id=tenant_id - ) - if doc_form and doc_form == 'qa_model': + text_generation_model = ModelFactory.get_text_generation_model( + tenant_id=tenant_id + ) if len(preview_texts) > 0: # qa model document - response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0]) + response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language) document_qa_list = self.format_split_text(response) return { "total_segments": total_segments * 20, @@ -289,13 +291,26 @@ class IndexingRunner: "preview": preview_texts } - def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: + def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, + doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict: """ Estimate the indexing for the document. 
""" - embedding_model = ModelFactory.get_embedding_model( - tenant_id=tenant_id - ) + if dataset_id: + dataset = Dataset.query.filter_by( + id=dataset_id + ).first() + if not dataset: + raise ValueError('Dataset not found.') + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + else: + embedding_model = ModelFactory.get_embedding_model( + tenant_id=tenant_id + ) # load data from notion tokens = 0 @@ -344,14 +359,13 @@ class IndexingRunner: tokens += embedding_model.get_num_tokens(document.page_content) - text_generation_model = ModelFactory.get_text_generation_model( - tenant_id=tenant_id - ) - if doc_form and doc_form == 'qa_model': + text_generation_model = ModelFactory.get_text_generation_model( + tenant_id=tenant_id + ) if len(preview_texts) > 0: # qa model document - response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0]) + response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language) document_qa_list = self.format_split_text(response) return { "total_segments": total_segments * 20, @@ -458,7 +472,8 @@ class IndexingRunner: splitter=splitter, processing_rule=processing_rule, tenant_id=dataset.tenant_id, - document_form=dataset_document.doc_form + document_form=dataset_document.doc_form, + document_language=dataset_document.doc_language ) # save node to document segment @@ -494,7 +509,8 @@ class IndexingRunner: return documents def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]: + processing_rule: DatasetProcessRule, tenant_id: str, + document_form: str, document_language: str) -> List[Document]: """ Split the text documents into nodes. 
""" @@ -523,8 +539,9 @@ class IndexingRunner: sub_documents = all_documents[i:i + 10] for doc in sub_documents: document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), 'tenant_id': tenant_id, 'document_node': doc, - 'all_qa_documents': all_qa_documents}) + 'flask_app': current_app._get_current_object(), + 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents, + 'document_language': document_language}) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -532,14 +549,14 @@ class IndexingRunner: return all_qa_documents return all_documents - def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents): + def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): format_documents = [] if document_node.page_content is None or not document_node.page_content.strip(): return with flask_app.app_context(): try: # qa model document - response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content) + response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) document_qa_list = self.format_split_text(response) qa_documents = [] for result in document_qa_list: @@ -641,7 +658,9 @@ class IndexingRunner: keyword_table_index = IndexBuilder.get_index(dataset, 'economy') embedding_model = ModelFactory.get_embedding_model( - tenant_id=dataset.tenant_id + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model ) # chunk nodes by chunk size @@ -722,6 +741,32 @@ class IndexingRunner: DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() + def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset): + """ + Batch add segments index processing + """ + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + documents.append(document) + # save vector index + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.add_texts(documents, duplicate_check=True) + + # save keyword index + index = IndexBuilder.get_index(dataset, 'economy') + if index: + index.add_texts(documents) + class DocumentIsPausedException(Exception): pass diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py index 8e829d47ae..f30d329ec2 100644 --- a/api/core/prompt/prompts.py +++ b/api/core/prompt/prompts.py @@ -44,13 +44,13 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( ) GENERATOR_QA_PROMPT = ( - "Please respond according to the language of the user's input text. If the text is in language [A], you must also reply in language [A].\n" + 'The user will send a long text. Please think step by step.' 'Step 1: Understand and summarize the main content of this text.\n' 'Step 2: What key information or concepts are mentioned in this text?\n' 'Step 3: Decompose or combine multiple pieces of information and concepts.\n' 'Step 4: Generate 20 questions and answers based on these key information and concepts.' 
'The questions should be clear and detailed, and the answers should be detailed and complete.\n' - "Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n" + "Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n" ) RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index 57ff10ae9b..a0fe89fe83 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -9,6 +9,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.embedding.cached_embedding import CacheEmbedding from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig from core.index.vector_index.vector_index import VectorIndex +from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_providers.model_factory import ModelFactory from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment @@ -70,10 +71,17 @@ class DatasetRetrieverTool(BaseTool): documents = kw_table_index.search(query, search_kwargs={'k': self.k}) return str("\n".join([document.page_content for document in documents])) else: - embedding_model = ModelFactory.get_embedding_model( - tenant_id=dataset.tenant_id - ) + try: + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + except LLMBadRequestError: + return '' + except ProviderTokenNotInitError: + return '' embeddings = CacheEmbedding(embedding_model) vector_index = VectorIndex( diff --git a/api/migrations/versions/2c8af9671032_add_qa_document_language.py b/api/migrations/versions/2c8af9671032_add_qa_document_language.py new file mode 100644 index 0000000000..5ad90fa31c --- /dev/null +++ b/api/migrations/versions/2c8af9671032_add_qa_document_language.py @@ -0,0 +1,32 @@ +"""add_qa_document_language + +Revision ID: 2c8af9671032 +Revises: 8d2d099ceb74 +Create Date: 2023-08-01 18:57:27.294973 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '2c8af9671032' +down_revision = '5022897aaceb' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('doc_language', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_column('doc_language') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py b/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py new file mode 100644 index 0000000000..67eaf35e52 --- /dev/null +++ b/api/migrations/versions/e8883b0148c9_add_dataset_model_name.py @@ -0,0 +1,34 @@ +"""add_dataset_model_name + +Revision ID: e8883b0148c9 +Revises: 2c8af9671032 +Create Date: 2023-08-15 20:54:58.936787 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = 'e8883b0148c9' +down_revision = '2c8af9671032' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('embedding_model', sa.String(length=255), server_default=sa.text("'text-embedding-ada-002'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('embedding_model_provider', sa.String(length=255), server_default=sa.text("'openai'::character varying"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('embedding_model_provider') + batch_op.drop_column('embedding_model') + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index ecf087ef65..6f7891a163 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -36,6 +36,10 @@ class Dataset(db.Model): updated_by = db.Column(UUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + embedding_model = db.Column(db.String( + 255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")) + embedding_model_provider = db.Column(db.String( + 255), nullable=False, server_default=db.text("'openai'::character varying")) @property def dataset_keyword_table(self): @@ -209,6 +213,7 @@ class Document(db.Model): doc_metadata = db.Column(db.JSON, nullable=True) doc_form = db.Column(db.String( 255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_language = db.Column(db.String(255), nullable=True) DATA_SOURCES = ['upload_file', 'notion_import'] diff --git a/api/requirements.txt b/api/requirements.txt index ac87a58ea7..d7d546c856 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -47,4 +47,5 @@ websocket-client~=1.6.1 dashscope~=1.5.0 huggingface_hub~=0.16.4 transformers~=4.31.0 -stripe~=5.5.0 \ No newline at end of file +stripe~=5.5.0 +pandas==1.5.3 \ No newline at end of file diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 5edd4b3da8..636f0fa08f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -9,6 +9,7 @@ from typing import Optional, List from flask import current_app from sqlalchemy import func +from core.index.index import IndexBuilder from core.model_providers.model_factory import ModelFactory from extensions.ext_redis import redis_client from flask_login import current_user @@ -25,14 +26,16 @@ from services.errors.account import NoPermissionError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError +from services.vector_service import VectorService from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task from tasks.create_segment_to_index_task import create_segment_to_index_task from tasks.update_segment_index_task import update_segment_index_task -from tasks.update_segment_keyword_index_task\ - import update_segment_keyword_index_task +from tasks.recover_document_indexing_task 
import recover_document_indexing_task +from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task +from tasks.delete_segment_from_index_task import delete_segment_from_index_task class DatasetService: @@ -88,12 +91,16 @@ class DatasetService: if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError( f'Dataset with name {name} already exists.') - + embedding_model = ModelFactory.get_embedding_model( + tenant_id=current_user.current_tenant_id + ) dataset = Dataset(name=name, indexing_technique=indexing_technique) # dataset = Dataset(name=name, provider=provider, config=config) dataset.created_by = account.id dataset.updated_by = account.id dataset.tenant_id = tenant_id + dataset.embedding_model_provider = embedding_model.model_provider.provider_name + dataset.embedding_model = embedding_model.name db.session.add(dataset) db.session.commit() return dataset @@ -372,7 +379,7 @@ class DocumentService: indexing_cache_key = 'document_{}_is_paused'.format(document.id) redis_client.delete(indexing_cache_key) # trigger async task - document_indexing_task.delay(document.dataset_id, document.id) + recover_document_indexing_task.delay(document.dataset_id, document.id) @staticmethod def get_documents_position(dataset_id): @@ -450,6 +457,7 @@ class DocumentService: document = DocumentService.save_document(dataset, dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], + document_data["doc_language"], data_source_info, created_from, position, account, file_name, batch) db.session.add(document) @@ -495,20 +503,11 @@ class DocumentService: document = DocumentService.save_document(dataset, dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], + document_data["doc_language"], data_source_info, created_from, position, account, page['page_name'], batch) - # if page['type'] == 'database': - # document.splitting_completed_at = datetime.datetime.utcnow() - # document.cleaning_completed_at = datetime.datetime.utcnow() - # document.parsing_completed_at = datetime.datetime.utcnow() - # document.completed_at = datetime.datetime.utcnow() - # document.indexing_status = 'completed' - # document.word_count = 0 - # document.tokens = 0 - # document.indexing_latency = 0 db.session.add(document) db.session.flush() - # if page['type'] != 'database': document_ids.append(document.id) documents.append(document) position += 1 @@ -520,15 +519,15 @@ class DocumentService: db.session.commit() # trigger async task - #document_index_created.send(dataset.id, document_ids=document_ids) document_indexing_task.delay(dataset.id, document_ids) return documents, batch @staticmethod def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str, - data_source_info: dict, created_from: str, position: int, account: Account, name: str, - batch: str): + document_language: str, data_source_info: dict, created_from: str, position: int, + account: Account, + name: str, batch: str): document = Document( tenant_id=dataset.tenant_id, dataset_id=dataset.id, @@ -540,7 +539,8 @@ class DocumentService: name=name, created_from=created_from, created_by=account.id, - doc_form=document_form + doc_form=document_form, + doc_language=document_language ) return document @@ -654,13 +654,18 @@ class DocumentService: tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT']) if documents_count > tenant_document_count: raise ValueError(f"over document limit 
{tenant_document_count}.") + embedding_model = ModelFactory.get_embedding_model( + tenant_id=tenant_id + ) # save dataset dataset = Dataset( tenant_id=tenant_id, name='', data_source_type=document_data["data_source"]["type"], indexing_technique=document_data["indexing_technique"], - created_by=account.id + created_by=account.id, + embedding_model=embedding_model.name, + embedding_model_provider=embedding_model.model_provider.provider_name ) db.session.add(dataset) @@ -870,13 +875,15 @@ class SegmentService: raise ValueError("Answer is required") @classmethod - def create_segment(cls, args: dict, document: Document): + def create_segment(cls, args: dict, document: Document, dataset: Dataset): content = args['content'] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) embedding_model = ModelFactory.get_embedding_model( - tenant_id=document.tenant_id + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model ) # calc embedding use tokens @@ -894,6 +901,9 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, + status='completed', + indexing_at=datetime.datetime.utcnow(), + completed_at=datetime.datetime.utcnow(), created_by=current_user.id ) if document.doc_form == 'qa_model': @@ -901,49 +911,88 @@ class SegmentService: db.session.add(segment_document) db.session.commit() - indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id) - redis_client.setex(indexing_cache_key, 600, 1) - create_segment_to_index_task.delay(segment_document.id, args['keywords']) - return segment_document + + # save vector index + try: + VectorService.create_segment_vector(args['keywords'], segment_document, dataset) + except Exception as e: + logging.exception("create segment index failed") + segment_document.enabled = False + segment_document.disabled_at = datetime.datetime.utcnow() + segment_document.status = 'error' + segment_document.error = str(e) + db.session.commit() + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() + return segment @classmethod - def update_segment(cls, args: dict, segment: DocumentSegment, document: Document): + def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): indexing_cache_key = 'segment_{}_indexing'.format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") - content = args['content'] - if segment.content == content: - if document.doc_form == 'qa_model': - segment.answer = args['answer'] - if args['keywords']: - segment.keywords = args['keywords'] - db.session.add(segment) - db.session.commit() - # update segment index task - redis_client.setex(indexing_cache_key, 600, 1) - update_segment_keyword_index_task.delay(segment.id) - else: - segment_hash = helper.generate_text_hash(content) + try: + content = args['content'] + if segment.content == content: + if document.doc_form == 'qa_model': + segment.answer = args['answer'] + if args['keywords']: + segment.keywords = args['keywords'] + db.session.add(segment) + db.session.commit() + # update segment index task + if args['keywords']: + kw_index = IndexBuilder.get_index(dataset, 'economy') + # delete from keyword index + kw_index.delete_by_ids([segment.index_node_id]) + # save keyword index + kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords) + else: + segment_hash = 
helper.generate_text_hash(content) - embedding_model = ModelFactory.get_embedding_model( - tenant_id=document.tenant_id - ) + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) - # calc embedding use tokens - tokens = embedding_model.get_num_tokens(content) - segment.content = content - segment.index_node_hash = segment_hash - segment.word_count = len(content) - segment.tokens = tokens - segment.status = 'updating' - segment.updated_by = current_user.id - segment.updated_at = datetime.datetime.utcnow() - if document.doc_form == 'qa_model': - segment.answer = args['answer'] - db.session.add(segment) + # calc embedding use tokens + tokens = embedding_model.get_num_tokens(content) + segment.content = content + segment.index_node_hash = segment_hash + segment.word_count = len(content) + segment.tokens = tokens + segment.status = 'completed' + segment.indexing_at = datetime.datetime.utcnow() + segment.completed_at = datetime.datetime.utcnow() + segment.updated_by = current_user.id + segment.updated_at = datetime.datetime.utcnow() + if document.doc_form == 'qa_model': + segment.answer = args['answer'] + db.session.add(segment) + db.session.commit() + # update segment vector index + VectorService.create_segment_vector(args['keywords'], segment, dataset) + except Exception as e: + logging.exception("update segment index failed") + segment.enabled = False + segment.disabled_at = datetime.datetime.utcnow() + segment.status = 'error' + segment.error = str(e) db.session.commit() - # update segment index task - redis_client.setex(indexing_cache_key, 600, 1) - update_segment_index_task.delay(segment.id, args['keywords']) + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() return segment + + @classmethod + def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): + indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise ValueError("Segment is deleting.") + # send delete segment index task + redis_client.setex(indexing_cache_key, 600, 1) + # enabled segment need to delete index + if segment.enabled: + delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id) + db.session.delete(segment) + db.session.commit() diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 3c1247ba56..4c2a8bf904 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -29,7 +29,9 @@ class HitTestingService: } embedding_model = ModelFactory.get_embedding_model( - tenant_id=dataset.tenant_id + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model ) embeddings = CacheEmbedding(embedding_model) diff --git a/api/services/vector_service.py b/api/services/vector_service.py new file mode 100644 index 0000000000..3cb8eb0099 --- /dev/null +++ b/api/services/vector_service.py @@ -0,0 +1,69 @@ + +from typing import Optional, List + +from langchain.schema import Document + +from core.index.index import IndexBuilder + +from models.dataset import Dataset, DocumentSegment + + +class VectorService: + + @classmethod + def create_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset): + document = Document( + 
page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + + # save vector index + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.add_texts([document], duplicate_check=True) + + # save keyword index + index = IndexBuilder.get_index(dataset, 'economy') + if index: + if keywords and len(keywords) > 0: + index.create_segment_keywords(segment.index_node_id, keywords) + else: + index.add_texts([document]) + + @classmethod + def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset): + # update segment index task + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') + # delete from vector index + if vector_index: + vector_index.delete_by_ids([segment.index_node_id]) + + # delete from keyword index + kw_index.delete_by_ids([segment.index_node_id]) + + # add new index + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + + # save vector index + if vector_index: + vector_index.add_texts([document], duplicate_check=True) + + # save keyword index + if keywords and len(keywords) > 0: + kw_index.create_segment_keywords(segment.index_node_id, keywords) + else: + kw_index.add_texts([document]) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py new file mode 100644 index 0000000000..86421a0115 --- /dev/null +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -0,0 +1,95 @@ +import datetime +import logging +import time +import uuid +from typing import Optional, List + +import click +from celery import shared_task +from sqlalchemy import func +from werkzeug.exceptions import NotFound + +from core.index.index import IndexBuilder +from core.indexing_runner import IndexingRunner +from core.model_providers.model_factory import ModelFactory +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs import helper +from models.dataset import DocumentSegment, Dataset, Document + + +@shared_task(queue='dataset') +def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: str, document_id: str, + tenant_id: str, user_id: str): + """ + Async batch create segment to index + :param job_id: + :param content: + :param dataset_id: + :param document_id: + :param tenant_id: + :param user_id: + + Usage: batch_create_segment_to_index_task.delay(segment_id) + """ + logging.info(click.style('Start batch create segment jobId: {}'.format(job_id), fg='green')) + start_at = time.perf_counter() + + indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError('Dataset not exist.') + + dataset_document = db.session.query(Document).filter(Document.id == document_id).first() + if not dataset_document: + raise ValueError('Document not exist.') + + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': + raise ValueError('Document is not available.') + document_segments = [] + for segment in content: + content = segment['content'] + doc_id = str(uuid.uuid4()) + segment_hash = 
helper.generate_text_hash(content) + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id, + model_provider_name=dataset.embedding_model_provider, + model_name=dataset.embedding_model + ) + + # calc embedding use tokens + tokens = embedding_model.get_num_tokens(content) + max_position = db.session.query(func.max(DocumentSegment.position)).filter( + DocumentSegment.document_id == dataset_document.id + ).scalar() + segment_document = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + created_by=user_id, + indexing_at=datetime.datetime.utcnow(), + status='completed', + completed_at=datetime.datetime.utcnow() + ) + if dataset_document.doc_form == 'qa_model': + segment_document.answer = segment['answer'] + db.session.add(segment_document) + document_segments.append(segment_document) + # add index to db + indexing_runner = IndexingRunner() + indexing_runner.batch_add_segments(document_segments, dataset) + db.session.commit() + redis_client.setex(indexing_cache_key, 600, 'completed') + end_at = time.perf_counter() + logging.info(click.style('Segment batch created job: {} latency: {}'.format(job_id, end_at - start_at), fg='green')) + except Exception as e: + logging.exception("Segments batch created index failed:{}".format(str(e))) + redis_client.setex(indexing_cache_key, 600, 'error') diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py new file mode 100644 index 0000000000..bb5a87410f --- /dev/null +++ b/api/tasks/delete_segment_from_index_task.py @@ -0,0 +1,58 @@ +import logging +import time + +import click +from celery import shared_task +from werkzeug.exceptions import NotFound + +from core.index.index import IndexBuilder +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment, Dataset, Document + + +@shared_task(queue='dataset') +def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): + """ + Async Remove segment from index + :param segment_id: + :param index_node_id: + :param dataset_id: + :param document_id: + + Usage: delete_segment_from_index_task.delay(segment_id) + """ + logging.info(click.style('Start delete segment from index: {}'.format(segment_id), fg='green')) + start_at = time.perf_counter() + indexing_cache_key = 'segment_{}_delete_indexing'.format(segment_id) + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style('Segment {} has no dataset, pass.'.format(segment_id), fg='cyan')) + return + + dataset_document = db.session.query(Document).filter(Document.id == document_id).first() + if not dataset_document: + logging.info(click.style('Segment {} has no document, pass.'.format(segment_id), fg='cyan')) + return + + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': + logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan')) + return + + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') + + # delete from vector index + if vector_index: + vector_index.delete_by_ids([index_node_id]) + + # delete from keyword index + 
kw_index.delete_by_ids([index_node_id]) + + end_at = time.perf_counter() + logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green')) + except Exception: + logging.exception("delete segment from index failed") + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/remove_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py similarity index 89% rename from api/tasks/remove_segment_from_index_task.py rename to api/tasks/disable_segment_from_index_task.py index b0f118649d..d6e75dba69 100644 --- a/api/tasks/remove_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -12,14 +12,14 @@ from models.dataset import DocumentSegment @shared_task(queue='dataset') -def remove_segment_from_index_task(segment_id: str): +def disable_segment_from_index_task(segment_id: str): """ - Async Remove segment from index + Async disable segment from index :param segment_id: - Usage: remove_segment_from_index.delay(segment_id) + Usage: disable_segment_from_index_task.delay(segment_id) """ - logging.info(click.style('Start remove segment from index: {}'.format(segment_id), fg='green')) + logging.info(click.style('Start disable segment from index: {}'.format(segment_id), fg='green')) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() diff --git a/api/tasks/update_segment_keyword_index_task.py b/api/tasks/update_segment_keyword_index_task.py index de3dfd10df..284f73677c 100644 --- a/api/tasks/update_segment_keyword_index_task.py +++ b/api/tasks/update_segment_keyword_index_task.py @@ -52,17 +52,6 @@ def update_segment_keyword_index_task(segment_id: str): # delete from keyword index kw_index.delete_by_ids([segment.index_node_id]) - # add new index - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - # save keyword index index = IndexBuilder.get_index(dataset, 'economy') if index: diff --git a/web/app/(commonLayout)/datasets/DatasetCard.tsx b/web/app/(commonLayout)/datasets/DatasetCard.tsx index 30d9e776f2..8e4f29af9d 100644 --- a/web/app/(commonLayout)/datasets/DatasetCard.tsx +++ b/web/app/(commonLayout)/datasets/DatasetCard.tsx @@ -5,13 +5,14 @@ import Link from 'next/link' import type { MouseEventHandler } from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import classNames from 'classnames' +import cn from 'classnames' import style from '../list.module.css' import Confirm from '@/app/components/base/confirm' import { ToastContext } from '@/app/components/base/toast' import { deleteDataset } from '@/service/datasets' import AppIcon from '@/app/components/base/app-icon' import type { DataSet } from '@/models/datasets' +import Tooltip from '@/app/components/base/tooltip' export type DatasetCardProps = { dataset: DataSet @@ -45,26 +46,36 @@ const DatasetCard = ({ return ( <> - +
- -
-
{dataset.name}
+ +
+
+ {dataset.name} +
+ {!dataset.embedding_available && ( + + {t('dataset.unavailable')} + + )}
-
{dataset.description}
-
+
{dataset.description}
+
- + {dataset.document_count}{t('dataset.documentCount')} - + {Math.round(dataset.word_count / 1000)}{t('dataset.wordCount')} - + {dataset.app_count}{t('dataset.appCount')}
diff --git a/web/app/(commonLayout)/datasets/page.tsx b/web/app/(commonLayout)/datasets/page.tsx index 909c46a435..4513ad1900 100644 --- a/web/app/(commonLayout)/datasets/page.tsx +++ b/web/app/(commonLayout)/datasets/page.tsx @@ -1,13 +1,7 @@ -import classNames from 'classnames' -import { getLocaleOnServer } from '@/i18n/server' -import { useTranslation } from '@/i18n/i18next-serverside-config' import Datasets from './Datasets' import DatasetFooter from './DatasetFooter' const AppList = async () => { - const locale = getLocaleOnServer() - const { t } = await useTranslation(locale, 'dataset') - return (
diff --git a/web/app/(commonLayout)/list.module.css b/web/app/(commonLayout)/list.module.css index bb63290a0f..84108a378e 100644 --- a/web/app/(commonLayout)/list.module.css +++ b/web/app/(commonLayout)/list.module.css @@ -192,3 +192,11 @@ @apply inline-flex items-center mb-2 text-sm font-medium; } /* #endregion new app dialog */ + +.unavailable { + @apply opacity-50; +} + +.listItem:hover .unavailable { + @apply opacity-100; +} diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.tsx index a976315ff8..2ad7ad0d66 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/index.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.tsx @@ -7,6 +7,7 @@ import TypeIcon from '../type-icon' import RemoveIcon from '../../base/icons/remove-icon' import s from './style.module.css' import { formatNumber } from '@/utils/format' +import Tooltip from '@/app/components/base/tooltip' export type ICardItemProps = { className?: string @@ -36,10 +37,22 @@ const CardItem: FC = ({ 'flex items-center justify-between rounded-xl px-3 py-2.5 bg-white border border-gray-200 cursor-pointer') }>
- +
+ +
-
{config.name}
-
+
+
{config.name}
+ {!config.embedding_available && ( + + {t('dataset.unavailable')} + + )} +
+
{formatNumber(config.word_count)} {t('appDebug.feature.dataSet.words')} · {formatNumber(config.document_count)} {t('appDebug.feature.dataSet.textBlocks')}
diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index 14312d9e48..1bcc742317 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -120,15 +120,24 @@ const SelectDataSet: FC = ({ {datasets.map(item => (
i.id === item.id) && s.selected, 'flex justify-between items-center h-10 px-2 rounded-lg bg-white border border-gray-200 cursor-pointer')} - onClick={() => toggleSelect(item)} + className={cn(s.item, selected.some(i => i.id === item.id) && s.selected, 'flex justify-between items-center h-10 px-2 rounded-lg bg-white border border-gray-200 cursor-pointer', !item.embedding_available && s.disabled)} + onClick={() => { + if (!item.embedding_available) + return + toggleSelect(item) + }} > -
- -
{item.name}
+
+
+ +
+
{item.name}
+ {!item.embedding_available && ( + {t('dataset.unavailable')} + )}
-
+
{formatNumber(item.word_count)} {t('appDebug.feature.dataSet.words')} · diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/style.module.css b/web/app/components/app/configuration/dataset-config/select-dataset/style.module.css index 9c73b88298..b560f29c43 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/style.module.css +++ b/web/app/components/app/configuration/dataset-config/select-dataset/style.module.css @@ -6,4 +6,8 @@ .item.selected { background: #F5F8FF; border-color: #528BFF; -} \ No newline at end of file +} + +.item.disabled { + @apply bg-white border-gray-200 cursor-default; +} diff --git a/web/app/components/base/icons/assets/vender/line/general/dots-horizontal.svg b/web/app/components/base/icons/assets/vender/line/general/dots-horizontal.svg index db7b0853d5..2e5eacff53 100644 --- a/web/app/components/base/icons/assets/vender/line/general/dots-horizontal.svg +++ b/web/app/components/base/icons/assets/vender/line/general/dots-horizontal.svg @@ -1,9 +1,9 @@ - + - - - + + + diff --git a/web/app/components/base/icons/src/vender/line/general/DotsHorizontal.json b/web/app/components/base/icons/src/vender/line/general/DotsHorizontal.json index 2dc0ee050b..38a493e0d3 100644 --- a/web/app/components/base/icons/src/vender/line/general/DotsHorizontal.json +++ b/web/app/components/base/icons/src/vender/line/general/DotsHorizontal.json @@ -4,9 +4,9 @@ "isRootNode": true, "name": "svg", "attributes": { - "width": "12", - "height": "12", - "viewBox": "0 0 12 12", + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", "fill": "none", "xmlns": "http://www.w3.org/2000/svg" }, @@ -29,7 +29,7 @@ "type": "element", "name": "path", "attributes": { - "d": "M6 6.5C6.27614 6.5 6.5 6.27614 6.5 6C6.5 5.72386 6.27614 5.5 6 5.5C5.72386 5.5 5.5 5.72386 5.5 6C5.5 6.27614 5.72386 6.5 6 6.5Z", + "d": "M8.00008 8.66634C8.36827 8.66634 8.66675 8.36786 8.66675 7.99967C8.66675 7.63148 8.36827 7.33301 8.00008 7.33301C7.63189 7.33301 7.33341 7.63148 7.33341 7.99967C7.33341 8.36786 7.63189 8.66634 8.00008 8.66634Z", "stroke": "currentColor", "stroke-width": "1.5", "stroke-linecap": "round", @@ -41,7 +41,7 @@ "type": "element", "name": "path", "attributes": { - "d": "M9.5 6.5C9.77614 6.5 10 6.27614 10 6C10 5.72386 9.77614 5.5 9.5 5.5C9.22386 5.5 9 5.72386 9 6C9 6.27614 9.22386 6.5 9.5 6.5Z", + "d": "M12.6667 8.66634C13.0349 8.66634 13.3334 8.36786 13.3334 7.99967C13.3334 7.63148 13.0349 7.33301 12.6667 7.33301C12.2986 7.33301 12.0001 7.63148 12.0001 7.99967C12.0001 8.36786 12.2986 8.66634 12.6667 8.66634Z", "stroke": "currentColor", "stroke-width": "1.5", "stroke-linecap": "round", @@ -53,7 +53,7 @@ "type": "element", "name": "path", "attributes": { - "d": "M2.5 6.5C2.77614 6.5 3 6.27614 3 6C3 5.72386 2.77614 5.5 2.5 5.5C2.22386 5.5 2 5.72386 2 6C2 6.27614 2.22386 6.5 2.5 6.5Z", + "d": "M3.33341 8.66634C3.7016 8.66634 4.00008 8.36786 4.00008 7.99967C4.00008 7.63148 3.7016 7.33301 3.33341 7.33301C2.96522 7.33301 2.66675 7.63148 2.66675 7.99967C2.66675 8.36786 2.96522 8.66634 3.33341 8.66634Z", "stroke": "currentColor", "stroke-width": "1.5", "stroke-linecap": "round", @@ -68,4 +68,4 @@ ] }, "name": "DotsHorizontal" -} \ No newline at end of file +} diff --git a/web/app/components/base/popover/index.tsx b/web/app/components/base/popover/index.tsx index 5cf62bd9df..cf8a352e16 100644 --- a/web/app/components/base/popover/index.tsx +++ b/web/app/components/base/popover/index.tsx @@ -9,6 +9,7 @@ type IPopover = { position?: 'bottom' | 'br' btnElement?: string 
| React.ReactNode btnClassName?: string | ((open: boolean) => string) + manualClose?: boolean } const timeoutDuration = 100 @@ -20,6 +21,7 @@ export default function CustomPopover({ btnElement, className, btnClassName, + manualClose, }: IPopover) { const buttonRef = useRef(null) const timeOutRef = useRef(null) @@ -62,17 +64,14 @@ export default function CustomPopover({ onMouseLeave(open), onMouseEnter: () => onMouseEnter(open), - })} + }) + } > {({ close }) => (
onMouseLeave(open), onMouseEnter: () => onMouseEnter(open), - })} + }) + } > {cloneElement(htmlContent as React.ReactElement, { - onClose: () => close(), + onClose: () => onMouseLeave(open), + ...(manualClose + ? { + onClick: close, + } + : {}), })}
)} diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index 32926c733c..e86f110bd7 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -29,7 +29,7 @@ const ACCEPTS = [ '.txt', // '.xls', '.xlsx', - '.csv', + // '.csv', ] const FileUploader = ({ diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index f72fe9ef6d..e7d5fb8e38 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -2,12 +2,14 @@ 'use client' import React, { useEffect, useLayoutEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' import { useBoolean } from 'ahooks' import { XMarkIcon } from '@heroicons/react/20/solid' import cn from 'classnames' import Link from 'next/link' import { groupBy } from 'lodash-es' import PreviewItem, { PreviewType } from './preview-item' +import LanguageSelect from './language-select' import s from './index.module.css' import type { CreateDocumentReq, CustomFile, FullDocumentDetail, FileIndexingEstimateResponse as IndexingEstimateResponse, NotionInfo, PreProcessingRule, Rules, createDocumentResponse } from '@/models/datasets' import { @@ -22,11 +24,13 @@ import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import { formatNumber } from '@/utils/format' import type { DataSourceNotionPage } from '@/models/common' -import { DataSourceType } from '@/models/datasets' +import { DataSourceType, DocForm } from '@/models/datasets' import NotionIcon from '@/app/components/base/notion-icon' import Switch from '@/app/components/base/switch' import { MessageChatSquare } from '@/app/components/base/icons/src/public/common' +import { XClose } from '@/app/components/base/icons/src/vender/line/general' import { useDatasetDetailContext } from '@/context/dataset-detail' +import I18n from '@/context/i18n' import { IS_CE_EDITION } from '@/config' type Page = DataSourceNotionPage & { workspace_id: string } @@ -56,10 +60,6 @@ enum IndexingType { QUALIFIED = 'high_quality', ECONOMICAL = 'economy', } -enum DocForm { - TEXT = 'text_model', - QA = 'qa_model', -} const StepTwo = ({ isSetting, @@ -78,6 +78,8 @@ const StepTwo = ({ onCancel, }: StepTwoProps) => { const { t } = useTranslation() + const { locale } = useContext(I18n) + const { mutateDatasetRes } = useDatasetDetailContext() const scrollRef = useRef(null) const [scrolled, setScrolled] = useState(false) @@ -98,6 +100,8 @@ const StepTwo = ({ const [docForm, setDocForm] = useState( datasetId && documentDetail ? documentDetail.doc_form : DocForm.TEXT, ) + const [docLanguage, setDocLanguage] = useState(locale === 'en' ? 
'English' : 'Chinese') + const [QATipHide, setQATipHide] = useState(false) const [previewSwitched, setPreviewSwitched] = useState(false) const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean() const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState(null) @@ -230,6 +234,8 @@ const StepTwo = ({ indexing_technique: getIndexing_technique(), process_rule: getProcessRule(), doc_form: docForm, + doc_language: docLanguage, + dataset_id: datasetId, } } if (dataSourceType === DataSourceType.NOTION) { @@ -241,6 +247,8 @@ const StepTwo = ({ indexing_technique: getIndexing_technique(), process_rule: getProcessRule(), doc_form: docForm, + doc_language: docLanguage, + dataset_id: datasetId, } } return params @@ -252,6 +260,7 @@ const StepTwo = ({ params = { original_document_id: documentDetail?.id, doc_form: docForm, + doc_language: docLanguage, process_rule: getProcessRule(), } as CreateDocumentReq } @@ -266,6 +275,7 @@ const StepTwo = ({ indexing_technique: getIndexing_technique(), process_rule: getProcessRule(), doc_form: docForm, + doc_language: docLanguage, } as CreateDocumentReq if (dataSourceType === DataSourceType.FILE) { params.data_source.info_list.file_info_list = { @@ -348,6 +358,10 @@ const StepTwo = ({ setDocForm(DocForm.TEXT) } + const handleSelect = (language: string) => { + setDocLanguage(language) + } + const changeToEconomicalType = () => { if (!hasSetIndexType) { setIndexType(IndexingType.ECONOMICAL) @@ -574,21 +588,32 @@ const StepTwo = ({
)} {IS_CE_EDITION && indexType === IndexingType.QUALIFIED && ( -
-
- -
-
-
{t('datasetCreation.stepTwo.QATitle')}
-
{t('datasetCreation.stepTwo.QATip')}
-
-
- +
+
+
+ +
+
+
{t('datasetCreation.stepTwo.QATitle')}
+
+ {t('datasetCreation.stepTwo.QALanguage')} + +
+
+
+ +
+ {docForm === DocForm.QA && !QATipHide && ( +
+ {t('datasetCreation.stepTwo.QATip')} + setQATipHide(true)} /> +
+ )}
)}
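For reference, the request that the step-two form above assembles now carries the Q&A segmentation mode and the newly added segment language alongside the data source and process rule. A minimal sketch of that payload for a file-based source follows; the data-source structure, IDs and rule values are illustrative placeholders, and only doc_language (plus the dataset_id hint) is what this patch adds to the request body:

    import json

    # Illustrative document-creation payload; IDs and rule values are placeholders.
    create_document_request = {
        "data_source": {
            "type": "upload_file",
            "info_list": {
                "data_source_type": "upload_file",
                "file_info_list": {"file_ids": ["<uploaded-file-id>"]},
            },
        },
        "indexing_technique": "high_quality",  # or "economy"
        "process_rule": {"mode": "automatic", "rules": {}},
        "doc_form": "qa_model",        # "text_model" when Q&A segmentation is off
        "doc_language": "English",     # value chosen in the new language selector
        "dataset_id": "<dataset-id>",  # sent when adding documents to an existing dataset
    }

    print(json.dumps(create_document_request, indent=2))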
diff --git a/web/app/components/datasets/create/step-two/language-select/index.tsx b/web/app/components/datasets/create/step-two/language-select/index.tsx new file mode 100644 index 0000000000..859a4c5823 --- /dev/null +++ b/web/app/components/datasets/create/step-two/language-select/index.tsx @@ -0,0 +1,38 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import cn from 'classnames' +import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows' +import Popover from '@/app/components/base/popover' + +export type ILanguageSelectProps = { + currentLanguage: string + onSelect: (language: string) => void +} + +const LanguageSelect: FC = ({ + currentLanguage, + onSelect, +}) => { + return ( + +
onSelect('English')}>English
+
onSelect('Chinese')}>简体中文
+
+ } + btnElement={ +
+ {currentLanguage === 'English' ? 'English' : '简体中文'} + +
+ } + btnClassName={open => cn('!border-0 !px-0 !py-0 !bg-inherit !hover:bg-inherit', open ? 'text-blue-600' : 'text-gray-500')} + className='!w-[120px] h-fit !z-20 !translate-x-0 !left-[-16px]' + /> + ) +} +export default React.memo(LanguageSelect) diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx new file mode 100644 index 0000000000..a2df97cf05 --- /dev/null +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx @@ -0,0 +1,108 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { + useCSVDownloader, +} from 'react-papaparse' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { Download02 as DownloadIcon } from '@/app/components/base/icons/src/vender/solid/general' +import { DocForm } from '@/models/datasets' +import I18n from '@/context/i18n' + +const CSV_TEMPLATE_QA_EN = [ + ['question', 'answer'], + ['question1', 'answer1'], + ['question2', 'answer2'], +] +const CSV_TEMPLATE_QA_CN = [ + ['问题', '答案'], + ['问题 1', '答案 1'], + ['问题 2', '答案 2'], +] +const CSV_TEMPLATE_EN = [ + ['segment content'], + ['content1'], + ['content2'], +] +const CSV_TEMPLATE_CN = [ + ['分段内容'], + ['内容 1'], + ['内容 2'], +] + +const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => { + const { t } = useTranslation() + const { locale } = useContext(I18n) + const { CSVDownloader, Type } = useCSVDownloader() + + const getTemplate = () => { + if (locale === 'en') { + if (docForm === DocForm.QA) + return CSV_TEMPLATE_QA_EN + return CSV_TEMPLATE_EN + } + if (docForm === DocForm.QA) + return CSV_TEMPLATE_QA_CN + return CSV_TEMPLATE_CN + } + + return ( +
+
{t('share.generation.csvStructureTitle')}
+
+ {docForm === DocForm.QA && ( + + + + + + + + + + + + + + + + + +
{t('datasetDocuments.list.batchModal.question')}{t('datasetDocuments.list.batchModal.answer')}
{t('datasetDocuments.list.batchModal.question')} 1{t('datasetDocuments.list.batchModal.answer')} 1
{t('datasetDocuments.list.batchModal.question')} 2{t('datasetDocuments.list.batchModal.answer')} 2
+ )} + {docForm === DocForm.TEXT && ( + + + + + + + + + + + + + + +
{t('datasetDocuments.list.batchModal.contentTitle')}
{t('datasetDocuments.list.batchModal.content')} 1
{t('datasetDocuments.list.batchModal.content')} 2
+ )} +
+ +
+ + {t('datasetDocuments.list.batchModal.template')} +
+
+
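The two tables above mirror the CSV_TEMPLATE_* constants defined at the top of this component. As a quick reference, the expected file layouts can be reproduced in a few lines of Python (file names are arbitrary):

    import csv

    # Q&A form: one question/answer pair per row (matches CSV_TEMPLATE_QA_EN).
    with open("batch_import_qa_template.csv", "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows([
            ["question", "answer"],
            ["question1", "answer1"],
            ["question2", "answer2"],
        ])

    # Text form: a single "segment content" column (matches CSV_TEMPLATE_EN).
    with open("batch_import_text_template.csv", "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows([
            ["segment content"],
            ["content1"],
            ["content2"],
        ])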
+ + ) +} +export default React.memo(CSVDownload) diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx new file mode 100644 index 0000000000..4802315c46 --- /dev/null +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx @@ -0,0 +1,126 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect, useRef, useState } from 'react' +import cn from 'classnames' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' +import { ToastContext } from '@/app/components/base/toast' +import { Trash03 } from '@/app/components/base/icons/src/vender/line/general' +import Button from '@/app/components/base/button' + +export type Props = { + file: File | undefined + updateFile: (file?: File) => void +} + +const CSVUploader: FC = ({ + file, + updateFile, +}) => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + const [dragging, setDragging] = useState(false) + const dropRef = useRef(null) + const dragRef = useRef(null) + const fileUploader = useRef(null) + + const handleDragEnter = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + e.target !== dragRef.current && setDragging(true) + } + const handleDragOver = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + } + const handleDragLeave = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + e.target === dragRef.current && setDragging(false) + } + const handleDrop = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + setDragging(false) + if (!e.dataTransfer) + return + const files = [...e.dataTransfer.files] + if (files.length > 1) { + notify({ type: 'error', message: t('datasetCreation.stepOne.uploader.validation.count') }) + return + } + updateFile(files[0]) + } + const selectHandle = () => { + if (fileUploader.current) + fileUploader.current.click() + } + const removeFile = () => { + if (fileUploader.current) + fileUploader.current.value = '' + updateFile() + } + const fileChangeHandle = (e: React.ChangeEvent) => { + const currentFile = e.target.files?.[0] + updateFile(currentFile) + } + + useEffect(() => { + dropRef.current?.addEventListener('dragenter', handleDragEnter) + dropRef.current?.addEventListener('dragover', handleDragOver) + dropRef.current?.addEventListener('dragleave', handleDragLeave) + dropRef.current?.addEventListener('drop', handleDrop) + return () => { + dropRef.current?.removeEventListener('dragenter', handleDragEnter) + dropRef.current?.removeEventListener('dragover', handleDragOver) + dropRef.current?.removeEventListener('dragleave', handleDragLeave) + dropRef.current?.removeEventListener('drop', handleDrop) + } + }, []) + + return ( +
+ +
+ {!file && ( +
+
+ +
+ {t('datasetDocuments.list.batchModal.csvUploadTitle')} + {t('datasetDocuments.list.batchModal.browse')} +
+
+ {dragging &&
} +
+ )} + {file && ( +
+ +
{file.name.replace(/\.csv$/, '')} .csv
+
+ +
+
+ +
+
+
+ )} +
+
+ ) +} + +export default React.memo(CSVUploader) diff --git a/web/app/components/datasets/documents/detail/batch-modal/index.tsx b/web/app/components/datasets/documents/detail/batch-modal/index.tsx new file mode 100644 index 0000000000..cfffedd6ad --- /dev/null +++ b/web/app/components/datasets/documents/detail/batch-modal/index.tsx @@ -0,0 +1,65 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import CSVUploader from './csv-uploader' +import CSVDownloader from './csv-downloader' +import Button from '@/app/components/base/button' +import Modal from '@/app/components/base/modal' +import { XClose } from '@/app/components/base/icons/src/vender/line/general' +import type { DocForm } from '@/models/datasets' + +export type IBatchModalProps = { + isShow: boolean + docForm: DocForm + onCancel: () => void + onConfirm: (file: File) => void +} + +const BatchModal: FC = ({ + isShow, + docForm, + onCancel, + onConfirm, +}) => { + const { t } = useTranslation() + const [currentCSV, setCurrentCSV] = useState() + const handleFile = (file?: File) => setCurrentCSV(file) + + const handleSend = () => { + if (!currentCSV) + return + onCancel() + onConfirm(currentCSV) + } + + useEffect(() => { + if (!isShow) + setCurrentCSV(undefined) + }, [isShow]) + + return ( + {}} className='px-8 py-6 !max-w-[520px] !rounded-xl'> +
{t('datasetDocuments.list.batchModal.title')}
+
+ +
+ + +
+ + +
+
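For context, confirming this modal posts the selected CSV to the segments batch-import endpoint and then polls the job until it leaves the waiting/processing states; the surrounding document-detail page re-checks roughly every 2.5 seconds. A rough client-side sketch of the same flow is shown below. The base URL, auth header and IDs are assumptions; only the request paths and the job_id/job_status fields come from this patch:

    import time
    import requests

    BASE = "https://example.com/console/api"                       # assumed deployment URL
    HEADERS = {"Authorization": "Bearer <console-session-token>"}   # assumed auth scheme
    dataset_id, document_id = "<dataset-id>", "<document-id>"

    # 1. Upload the CSV as multipart form data (segmentBatchImport in web/service/datasets.ts).
    with open("batch_import_qa_template.csv", "rb") as f:
        job = requests.post(
            f"{BASE}/datasets/{dataset_id}/documents/{document_id}/segments/batch_import",
            headers=HEADERS,
            files={"file": f},
        ).json()

    # 2. Poll the status endpoint until the async task reports completed or error.
    while job["job_status"] in ("waiting", "processing"):
        time.sleep(2.5)
        job = requests.get(
            f"{BASE}/datasets/batch_import_status/{job['job_id']}",
            headers=HEADERS,
        ).json()

    print("batch import finished:", job["job_status"])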
+ ) +} +export default React.memo(BatchModal) diff --git a/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx b/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx index 0bbd91f877..c4365091f0 100644 --- a/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx +++ b/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx @@ -13,6 +13,9 @@ type IInfiniteVirtualListProps = { loadNextPage: () => Promise // Callback function responsible for loading the next page of items. onClick: (detail: SegmentDetailModel) => void onChangeSwitch: (segId: string, enabled: boolean) => Promise + onDelete: (segId: string) => Promise + archived?: boolean + } const InfiniteVirtualList: FC = ({ @@ -22,6 +25,8 @@ const InfiniteVirtualList: FC = ({ loadNextPage, onClick: onClickCard, onChangeSwitch, + onDelete, + archived, }) => { // If there are more items to be loaded then add an extra row to hold a loading indicator. const itemCount = hasNextPage ? items.length + 1 : items.length @@ -52,7 +57,9 @@ const InfiniteVirtualList: FC = ({ detail={segItem} onClick={() => onClickCard(segItem)} onChangeSwitch={onChangeSwitch} + onDelete={onDelete} loading={false} + archived={archived} /> )) } diff --git a/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx b/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx index d32aa10243..bfdc9c9b1c 100644 --- a/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx +++ b/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import React from 'react' +import React, { useState } from 'react' import cn from 'classnames' import { ArrowUpRightIcon } from '@heroicons/react/24/outline' import { useTranslation } from 'react-i18next' @@ -7,11 +7,15 @@ import { StatusItem } from '../../list' import { DocumentTitle } from '../index' import s from './style.module.css' import { SegmentIndexTag } from './index' +import Modal from '@/app/components/base/modal' +import Button from '@/app/components/base/button' import Switch from '@/app/components/base/switch' import Divider from '@/app/components/base/divider' import Indicator from '@/app/components/header/indicator' import { formatNumber } from '@/utils/format' import type { SegmentDetailModel } from '@/models/datasets' +import { AlertCircle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' +import { Trash03 } from '@/app/components/base/icons/src/vender/line/general' const ProgressBar: FC<{ percent: number; loading: boolean }> = ({ percent, loading }) => { return ( @@ -35,8 +39,10 @@ type ISegmentCardProps = { score?: number onClick?: () => void onChangeSwitch?: (segId: string, enabled: boolean) => Promise + onDelete?: (segId: string) => Promise scene?: UsageScene className?: string + archived?: boolean } const SegmentCard: FC = ({ @@ -44,9 +50,11 @@ const SegmentCard: FC = ({ score, onClick, onChangeSwitch, + onDelete, loading = true, scene = 'doc', className = '', + archived, }) => { const { t } = useTranslation() const { @@ -60,6 +68,7 @@ const SegmentCard: FC = ({ answer, } = detail as any const isDocScene = scene === 'doc' + const [showModal, setShowModal] = useState(false) const renderContent = () => { if (answer) { @@ -86,7 +95,7 @@ const SegmentCard: FC = ({ s.segWrapper, (isDocScene && !enabled) ? 'bg-gray-25' : '', 'group', - !loading ? 'pb-4' : '', + !loading ? 
'pb-4 hover:pb-[10px]' : '', className, )} onClick={() => onClick?.()} @@ -116,6 +125,7 @@ const SegmentCard: FC = ({ > { await onChangeSwitch?.(id, val) @@ -159,10 +169,18 @@ const SegmentCard: FC = ({
{formatNumber(hit_count)}
-
+
{index_node_hash}
+ {!archived && ( +
{ + e.stopPropagation() + setShowModal(true) + }}> + +
+ )}
: <> @@ -187,6 +205,26 @@ const SegmentCard: FC = ({
)} + {showModal && setShowModal(false)} className={s.delModal} closable> +
+
+ +
+
{t('datasetDocuments.segment.delete')}
+
+ + +
+
+
}
) } diff --git a/web/app/components/datasets/documents/detail/completed/index.tsx b/web/app/components/datasets/documents/detail/completed/index.tsx index cbfc534218..34744d6c60 100644 --- a/web/app/components/datasets/documents/detail/completed/index.tsx +++ b/web/app/components/datasets/documents/detail/completed/index.tsx @@ -8,6 +8,7 @@ import { debounce, isNil, omitBy } from 'lodash-es' import cn from 'classnames' import { StatusItem } from '../../list' import { DocumentContext } from '../index' +import { ProcessStatus } from '../segment-add' import s from './style.module.css' import InfiniteVirtualList from './InfiniteVirtualList' import { formatNumber } from '@/utils/format' @@ -18,7 +19,7 @@ import Input from '@/app/components/base/input' import { ToastContext } from '@/app/components/base/toast' import type { Item } from '@/app/components/base/select' import { SimpleSelect } from '@/app/components/base/select' -import { disableSegment, enableSegment, fetchSegments, updateSegment } from '@/service/datasets' +import { deleteSegment, disableSegment, enableSegment, fetchSegments, updateSegment } from '@/service/datasets' import type { SegmentDetailModel, SegmentUpdator, SegmentsQuery, SegmentsResponse } from '@/models/datasets' import { asyncRunSafe } from '@/utils' import type { CommonResponse } from '@/models/common' @@ -48,12 +49,14 @@ type ISegmentDetailProps = { onChangeSwitch?: (segId: string, enabled: boolean) => Promise onUpdate: (segmentId: string, q: string, a: string, k: string[]) => void onCancel: () => void + archived?: boolean } /** * Show all the contents of the segment */ export const SegmentDetail: FC = memo(({ segInfo, + archived, onChangeSwitch, onUpdate, onCancel, @@ -116,31 +119,30 @@ export const SegmentDetail: FC = memo(({ return (
- { - isEditing - ? ( - <> - - - - ) - : ( -
-
{t('common.operation.edit')}
- setIsEditing(true)} /> -
- ) - } -
+ {isEditing && ( + <> + + + + )} + {!isEditing && !archived && ( + <> +
+
{t('common.operation.edit')}
+ setIsEditing(true)} /> +
+
+ + )}
@@ -176,6 +178,7 @@ export const SegmentDetail: FC = memo(({ onChange={async (val) => { await onChangeSwitch?.(segInfo?.id || '', val) }} + disabled={archived} />
@@ -195,13 +198,20 @@ export const splitArray = (arr: any[], size = 3) => { type ICompletedProps = { showNewSegmentModal: boolean onNewSegmentModalChange: (state: boolean) => void + importStatus: ProcessStatus | string | undefined + archived?: boolean // data: Array<{}> // all/part segments } /** * Embedding done, show list of all segments * Support search and filter */ -const Completed: FC = ({ showNewSegmentModal, onNewSegmentModalChange }) => { +const Completed: FC = ({ + showNewSegmentModal, + onNewSegmentModalChange, + importStatus, + archived, +}) => { const { t } = useTranslation() const { notify } = useContext(ToastContext) const { datasetId = '', documentId = '', docForm } = useContext(DocumentContext) @@ -250,11 +260,6 @@ const Completed: FC = ({ showNewSegmentModal, onNewSegmentModal getSegments(false) } - useEffect(() => { - if (lastSegmentsRes !== undefined) - getSegments(false) - }, [selectedStatus, searchValue]) - const onClickCard = (detail: SegmentDetailModel) => { setCurrSegment({ segInfo: detail, showModal: true }) } @@ -281,6 +286,17 @@ const Completed: FC = ({ showNewSegmentModal, onNewSegmentModal } } + const onDelete = async (segId: string) => { + const [e] = await asyncRunSafe(deleteSegment({ datasetId, documentId, segmentId: segId }) as Promise) + if (!e) { + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + resetList() + } + else { + notify({ type: 'error', message: t('common.actionMsg.modificationFailed') }) + } + } + const handleUpdateSegment = async (segmentId: string, question: string, answer: string, keywords: string[]) => { const params: SegmentUpdator = { content: '' } if (docForm === 'qa_model') { @@ -321,6 +337,16 @@ const Completed: FC = ({ showNewSegmentModal, onNewSegmentModal setAllSegments([...allSegments]) } + useEffect(() => { + if (lastSegmentsRes !== undefined) + getSegments(false) + }, [selectedStatus, searchValue]) + + useEffect(() => { + if (importStatus === ProcessStatus.COMPLETED) + resetList() + }, [importStatus]) + return ( <>
@@ -343,7 +369,9 @@ const Completed: FC = ({ showNewSegmentModal, onNewSegmentModal items={allSegments} loadNextPage={getSegments} onChangeSwitch={onChangeSwitch} + onDelete={onDelete} onClick={onClickCard} + archived={archived} /> {}} className='!max-w-[640px] !overflow-visible'> = ({ showNewSegmentModal, onNewSegmentModal onChangeSwitch={onChangeSwitch} onUpdate={handleUpdateSegment} onCancel={onCloseModal} + archived={archived} /> void }> = ({ onClick }) => { - return ( -
- -
- ) -} +import { checkSegmentBatchImportProgress, fetchDocumentDetail, segmentBatchImport } from '@/service/datasets' +import { ToastContext } from '@/app/components/base/toast' +import type { DocForm } from '@/models/datasets' export const DocumentContext = createContext<{ datasetId?: string; documentId?: string; docForm: string }>({ docForm: '' }) @@ -51,10 +47,45 @@ type Props = { } const DocumentDetail: FC = ({ datasetId, documentId }) => { - const { t } = useTranslation() const router = useRouter() + const { t } = useTranslation() + const { notify } = useContext(ToastContext) const [showMetadata, setShowMetadata] = useState(true) - const [showNewSegmentModal, setShowNewSegmentModal] = useState(false) + const [newSegmentModalVisible, setNewSegmentModalVisible] = useState(false) + const [batchModalVisible, setBatchModalVisible] = useState(false) + const [importStatus, setImportStatus] = useState() + const showNewSegmentModal = () => setNewSegmentModalVisible(true) + const showBatchModal = () => setBatchModalVisible(true) + const hideBatchModal = () => setBatchModalVisible(false) + const resetProcessStatus = () => setImportStatus('') + const checkProcess = async (jobID: string) => { + try { + const res = await checkSegmentBatchImportProgress({ jobID }) + setImportStatus(res.job_status) + if (res.job_status === ProcessStatus.WAITING || res.job_status === ProcessStatus.PROCESSING) + setTimeout(() => checkProcess(res.job_id), 2500) + if (res.job_status === ProcessStatus.ERROR) + notify({ type: 'error', message: `${t('datasetDocuments.list.batchModal.runError')}` }) + } + catch (e: any) { + notify({ type: 'error', message: `${t('datasetDocuments.list.batchModal.runError')}${'message' in e ? `: ${e.message}` : ''}` }) + } + } + const runBatch = async (csv: File) => { + const formData = new FormData() + formData.append('file', csv) + try { + const res = await segmentBatchImport({ + url: `/datasets/${datasetId}/documents/${documentId}/segments/batch_import`, + body: formData, + }) + setImportStatus(res.job_status) + checkProcess(res.job_id) + } + catch (e: any) { + notify({ type: 'error', message: `${t('datasetDocuments.list.batchModal.runError')}${'message' in e ? `: ${e.message}` : ''}` }) + } + } const { data: documentDetail, error, mutate: detailMutate } = useSWR({ action: 'fetchDocumentDetail', @@ -91,22 +122,32 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => {
- +
+ +
+ {documentDetail && !documentDetail.archived && ( + + )} setShowNewSegmentModal(true)} />
@@ -132,6 +175,12 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => { onUpdate={metadataMutate} />}
+
) diff --git a/web/app/components/datasets/documents/detail/segment-add/index.tsx b/web/app/components/datasets/documents/detail/segment-add/index.tsx new file mode 100644 index 0000000000..b24ca38a44 --- /dev/null +++ b/web/app/components/datasets/documents/detail/segment-add/index.tsx @@ -0,0 +1,84 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import cn from 'classnames' +import { FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' +import { Loading02 } from '@/app/components/base/icons/src/vender/line/general' +import { AlertCircle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' +import { CheckCircle } from '@/app/components/base/icons/src/vender/solid/general' +import Popover from '@/app/components/base/popover' + +export type ISegmentAddProps = { + importStatus: ProcessStatus | string | undefined + clearProcessStatus: () => void + showNewSegmentModal: () => void + showBatchModal: () => void +} + +export enum ProcessStatus { + WAITING = 'waiting', + PROCESSING = 'processing', + COMPLETED = 'completed', + ERROR = 'error', +} + +const SegmentAdd: FC = ({ + importStatus, + clearProcessStatus, + showNewSegmentModal, + showBatchModal, +}) => { + const { t } = useTranslation() + + if (importStatus) { + return ( + <> + {(importStatus === ProcessStatus.WAITING || importStatus === ProcessStatus.PROCESSING) && ( +
+ {importStatus === ProcessStatus.WAITING &&
} + {importStatus === ProcessStatus.PROCESSING &&
} + + {t('datasetDocuments.list.batchModal.processing')} +
+ )} + {importStatus === ProcessStatus.COMPLETED && ( +
+ + {t('datasetDocuments.list.batchModal.completed')} + {t('datasetDocuments.list.batchModal.ok')} +
+ )} + {importStatus === ProcessStatus.ERROR && ( +
+ + {t('datasetDocuments.list.batchModal.error')} + {t('datasetDocuments.list.batchModal.ok')} +
+ )} + + ) + } + + return ( + +
{t('datasetDocuments.list.action.add')}
+
{t('datasetDocuments.list.action.batchAdd')}
+
+ } + btnElement={ +
+ + {t('datasetDocuments.list.action.addButton')} +
+ } + btnClassName={open => cn('mr-2 !py-[6px] !text-[13px] !leading-[18px] hover:bg-gray-50 border border-gray-200 hover:border-gray-300 hover:shadow-[0_1px_2px_rgba(16,24,40,0.05)]', open ? '!bg-gray-100 !shadow-none' : '!bg-transparent')} + className='!w-[132px] h-fit !z-20 !translate-x-0 !left-0' + /> + ) +} +export default React.memo(SegmentAdd) diff --git a/web/app/components/datasets/documents/list.tsx b/web/app/components/datasets/documents/list.tsx index 3075c3596b..df9aadd7f4 100644 --- a/web/app/components/datasets/documents/list.tsx +++ b/web/app/components/datasets/documents/list.tsx @@ -22,12 +22,12 @@ import type { IndicatorProps } from '@/app/components/header/indicator' import Indicator from '@/app/components/header/indicator' import { asyncRunSafe } from '@/utils' import { formatNumber } from '@/utils/format' -import { archiveDocument, deleteDocument, disableDocument, enableDocument, syncDocument } from '@/service/datasets' +import { archiveDocument, deleteDocument, disableDocument, enableDocument, syncDocument, unArchiveDocument } from '@/service/datasets' import NotionIcon from '@/app/components/base/notion-icon' import ProgressBar from '@/app/components/base/progress-bar' import { DataSourceType, type DocumentDisplayStatus, type SimpleDocumentDetail } from '@/models/datasets' import type { CommonResponse } from '@/models/common' -import { FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' +import { DotsHorizontal } from '@/app/components/base/icons/src/vender/line/general' export const SettingsIcon: FC<{ className?: string }> = ({ className }) => { return @@ -86,7 +86,7 @@ export const StatusItem: FC<{
} -type OperationName = 'delete' | 'archive' | 'enable' | 'disable' | 'sync' +type OperationName = 'delete' | 'archive' | 'enable' | 'disable' | 'sync' | 'un_archive' // operation action for list and detail export const OperationAction: FC<{ @@ -101,8 +101,7 @@ export const OperationAction: FC<{ onUpdate: (operationName?: string) => void scene?: 'list' | 'detail' className?: string - showNewSegmentModal?: () => void -}> = ({ datasetId, detail, onUpdate, scene = 'list', className = '', showNewSegmentModal }) => { +}> = ({ datasetId, detail, onUpdate, scene = 'list', className = '' }) => { const { id, enabled = false, archived = false, data_source_type } = detail || {} const [showModal, setShowModal] = useState(false) const { notify } = useContext(ToastContext) @@ -117,6 +116,9 @@ export const OperationAction: FC<{ case 'archive': opApi = archiveDocument break + case 'un_archive': + opApi = unArchiveDocument + break case 'enable': opApi = enableDocument break @@ -218,10 +220,72 @@ export const OperationAction: FC<{ } } + htmlContent={ +
+ {!isListScene && <> +
+ + {!archived && enabled ? t('datasetDocuments.list.index.enable') : t('datasetDocuments.list.index.disable')} + + +
+ !archived && onOperate(v ? 'enable' : 'disable')} + disabled={archived} + size='md' + /> +
+
+
+
+ {!archived && enabled ? t('datasetDocuments.list.index.enableTip') : t('datasetDocuments.list.index.disableTip')} +
+ + } + {!archived && ( + <> +
router.push(`/datasets/${datasetId}/documents/${detail.id}/settings`)}> + + {t('datasetDocuments.list.action.settings')} +
+ {data_source_type === 'notion_import' && ( +
onOperate('sync')}> + + {t('datasetDocuments.list.action.sync')} +
+ )} + + + )} + {!archived &&
onOperate('archive')}> + + {t('datasetDocuments.list.action.archive')} +
} + {archived && ( +
onOperate('un_archive')}> + + {t('datasetDocuments.list.action.unarchive')} +
+ )} +
setShowModal(true)}> + + {t('datasetDocuments.list.action.delete')} +
+
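Each entry in the menu above maps onto one of the document status endpoints; the un_archive action (and the matching unArchiveDocument helper in web/service/datasets.ts) is the one introduced by this patch. A minimal sketch of that call, reusing the assumed base URL and auth header from the earlier batch-import example:

    import requests

    BASE = "https://example.com/console/api"                       # assumption
    HEADERS = {"Authorization": "Bearer <console-session-token>"}  # assumption

    def set_document_status(dataset_id: str, document_id: str, action: str) -> None:
        """action: 'enable', 'disable', 'archive' or 'un_archive' (newly added)."""
        resp = requests.patch(
            f"{BASE}/datasets/{dataset_id}/documents/{document_id}/status/{action}",
            headers=HEADERS,
        )
        resp.raise_for_status()

    set_document_status("<dataset-id>", "<document-id>", "un_archive")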
+ } trigger='click' position='br' - btnElement={
} + btnElement={ +
+ +
+ } btnClassName={open => cn(isListScene ? s.actionIconWrapperList : s.actionIconWrapperDetail, open ? '!bg-gray-100 !shadow-none' : '!bg-transparent')} className={`!w-[200px] h-fit !z-20 ${className}`} /> diff --git a/web/app/components/datasets/settings/form/index.tsx b/web/app/components/datasets/settings/form/index.tsx index 6b627e11b9..e2454bba80 100644 --- a/web/app/components/datasets/settings/form/index.tsx +++ b/web/app/components/datasets/settings/form/index.tsx @@ -10,6 +10,10 @@ import { ToastContext } from '@/app/components/base/toast' import Button from '@/app/components/base/button' import { fetchDataDetail, updateDatasetSetting } from '@/service/datasets' import type { DataSet } from '@/models/datasets' +import ModelSelector from '@/app/components/header/account-setting/model-page/model-selector' +import type { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations' +import { ModelType } from '@/app/components/header/account-setting/model-page/declarations' +import AccountSetting from '@/app/components/header/account-setting' const rowClass = ` flex justify-between py-4 @@ -41,7 +45,7 @@ const Form = ({ const [description, setDescription] = useState(currentDataset?.description ?? '') const [permission, setPermission] = useState(currentDataset?.permission) const [indexMethod, setIndexMethod] = useState(currentDataset?.indexing_technique) - + const [showSetAPIKeyModal, setShowSetAPIKeyModal] = useState(false) const handleSave = async () => { if (loading) return @@ -128,6 +132,32 @@ const Form = ({ />
+
+
+
{t('datasetSettings.form.embeddingModel')}
+
+
+ {currentDataset && ( + <> +
+ {}} + /> +
+
+ {t('datasetSettings.form.embeddingModelTip')} + setShowSetAPIKeyModal(true)}>{t('datasetSettings.form.embeddingModelTipLink')} +
+ + )} +
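The selector above is rendered read-only because each dataset is now pinned to the embedding model it was created with, and the backend resolves that model again whenever new segments are embedded, as the batch import task earlier in this patch does. A condensed sketch of that lookup, assuming it runs inside the api project:

    # Sketch only: mirrors the resolution used by batch_create_segment_to_index_task above.
    from core.model_providers.model_factory import ModelFactory

    def estimate_segment_tokens(dataset, content: str) -> int:
        embedding_model = ModelFactory.get_embedding_model(
            tenant_id=dataset.tenant_id,
            model_provider_name=dataset.embedding_model_provider,  # new per-dataset field
            model_name=dataset.embedding_model,                    # new per-dataset field
        )
        # Same per-segment token accounting the import task performs.
        return embedding_model.get_num_tokens(content)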
+
@@ -140,6 +170,11 @@ const Form = ({
+ {showSetAPIKeyModal && ( + { + setShowSetAPIKeyModal(false) + }} /> + )}
) } diff --git a/web/i18n/lang/dataset-creation.en.ts b/web/i18n/lang/dataset-creation.en.ts index d218536e32..e5d1a286e4 100644 --- a/web/i18n/lang/dataset-creation.en.ts +++ b/web/i18n/lang/dataset-creation.en.ts @@ -75,6 +75,7 @@ const translation = { economicalTip: 'Use offline vector engines, keyword indexes, etc. to reduce accuracy without spending tokens', QATitle: 'Segmenting in Question & Answer format', QATip: 'Enable this option will consume more tokens', + QALanguage: 'Segment using', emstimateCost: 'Estimation', emstimateSegment: 'Estimated segments', segmentCount: 'segments', diff --git a/web/i18n/lang/dataset-creation.zh.ts b/web/i18n/lang/dataset-creation.zh.ts index 65b6d13fe0..f15bcc1e87 100644 --- a/web/i18n/lang/dataset-creation.zh.ts +++ b/web/i18n/lang/dataset-creation.zh.ts @@ -75,6 +75,7 @@ const translation = { economicalTip: '使用离线的向量引擎、关键词索引等方式,降低了准确度但无需花费 Token', QATitle: '采用 Q&A 分段模式', QATip: '开启后将会消耗额外的 token', + QALanguage: '分段使用', emstimateCost: '执行嵌入预估消耗', emstimateSegment: '预估分段数', segmentCount: '段', diff --git a/web/i18n/lang/dataset-documents.en.ts b/web/i18n/lang/dataset-documents.en.ts index 3d8f6decf0..c5fdcc4007 100644 --- a/web/i18n/lang/dataset-documents.en.ts +++ b/web/i18n/lang/dataset-documents.en.ts @@ -17,8 +17,11 @@ const translation = { action: { uploadFile: 'Upload new file', settings: 'Segment settings', - add: 'Add new segment', + addButton: 'Add segment', + add: 'Add a segment', + batchAdd: 'Batch add', archive: 'Archive', + unarchive: 'Unarchive', delete: 'Delete', enableWarning: 'Archived file cannot be enabled', sync: 'Sync', @@ -53,6 +56,24 @@ const translation = { title: 'Are you sure Delete?', content: 'If you need to resume processing later, you will continue from where you left off', }, + batchModal: { + title: 'Batch add segments', + csvUploadTitle: 'Drag and drop your CSV file here, or ', + browse: 'browse', + tip: 'The CSV file must conform to the following structure:', + question: 'question', + answer: 'answer', + contentTitle: 'segment content', + content: 'content', + template: 'Download the template here', + cancel: 'Cancel', + run: 'Run Batch', + runError: 'Run batch failed', + processing: 'In batch processing', + completed: 'Import completed', + error: 'Import Error', + ok: 'OK', + }, }, metadata: { title: 'Metadata', @@ -321,6 +342,7 @@ const translation = { contentEmpty: 'Content can not be empty', newTextSegment: 'New Text Segment', newQaSegment: 'New Q&A Segment', + delete: 'Delete this segment ?', }, } diff --git a/web/i18n/lang/dataset-documents.zh.ts b/web/i18n/lang/dataset-documents.zh.ts index 571e83a3a2..0ea31d190b 100644 --- a/web/i18n/lang/dataset-documents.zh.ts +++ b/web/i18n/lang/dataset-documents.zh.ts @@ -17,8 +17,11 @@ const translation = { action: { uploadFile: '上传新文件', settings: '分段设置', + addButton: '添加分段', add: '添加新分段', + batchAdd: '批量添加', archive: '归档', + unarchive: '撤销归档', delete: '删除', enableWarning: '归档的文件无法启用', sync: '同步', @@ -53,6 +56,24 @@ const translation = { title: '确定删除吗?', content: '如果您需要稍后恢复处理,您将从您离开的地方继续', }, + batchModal: { + title: '批量添加分段', + csvUploadTitle: '将您的 CSV 文件拖放到此处,或', + browse: '选择文件', + tip: 'CSV 文件必须符合以下结构:', + question: '问题', + answer: '回答', + contentTitle: '分段内容', + content: '内容', + template: '下载模板', + cancel: '取消', + run: '导入', + runError: '批量导入失败', + processing: '批量处理中', + completed: '导入完成', + error: '导入出错', + ok: '确定', + }, }, metadata: { title: '元数据', @@ -320,6 +341,7 @@ const translation = { contentEmpty: '内容不能为空', newTextSegment: '新文本分段', newQaSegment: '新问答分段', + 
delete: '删除这个分段?', }, } diff --git a/web/i18n/lang/dataset-settings.en.ts b/web/i18n/lang/dataset-settings.en.ts index 1337383ad4..c44ee7847c 100644 --- a/web/i18n/lang/dataset-settings.en.ts +++ b/web/i18n/lang/dataset-settings.en.ts @@ -15,6 +15,9 @@ const translation = { indexMethodHighQualityTip: 'Call OpenAI\'s embedding interface for processing to provide higher accuracy when users query.', indexMethodEconomy: 'Economical', indexMethodEconomyTip: 'Use offline vector engines, keyword indexes, etc. to reduce accuracy without spending tokens', + embeddingModel: 'Embedding Model', + embeddingModelTip: 'Change the embedded model, please go to ', + embeddingModelTipLink: 'Settings', save: 'Save', }, } diff --git a/web/i18n/lang/dataset-settings.zh.ts b/web/i18n/lang/dataset-settings.zh.ts index 0818836d79..91187be169 100644 --- a/web/i18n/lang/dataset-settings.zh.ts +++ b/web/i18n/lang/dataset-settings.zh.ts @@ -15,6 +15,9 @@ const translation = { indexMethodHighQualityTip: '调用 OpenAI 的嵌入接口进行处理,以在用户查询时提供更高的准确度', indexMethodEconomy: '经济', indexMethodEconomyTip: '使用离线的向量引擎、关键词索引等方式,降低了准确度但无需花费 Token', + embeddingModel: 'Embedding 模型', + embeddingModelTip: '修改 Embedding 模型,请去', + embeddingModelTipLink: '设置', save: '保存', }, } diff --git a/web/i18n/lang/dataset.en.ts b/web/i18n/lang/dataset.en.ts index 94d98c2206..680c274ef4 100644 --- a/web/i18n/lang/dataset.en.ts +++ b/web/i18n/lang/dataset.en.ts @@ -16,6 +16,8 @@ const translation = { intro4: 'or it ', intro5: 'can be created', intro6: ' as a standalone ChatGPT index plug-in to publish', + unavailable: 'Unavailable', + unavailableTip: 'Embedding model is not available, the default embedding model needs to be configured', } export default translation diff --git a/web/i18n/lang/dataset.zh.ts b/web/i18n/lang/dataset.zh.ts index 2174df2e05..f2e2fbb16f 100644 --- a/web/i18n/lang/dataset.zh.ts +++ b/web/i18n/lang/dataset.zh.ts @@ -16,6 +16,8 @@ const translation = { intro4: '或可以', intro5: '创建', intro6: '为独立的 ChatGPT 插件发布使用', + unavailable: '不可用', + unavailableTip: '由于 embedding 模型不可用,需要配置默认 embedding 模型', } export default translation diff --git a/web/models/datasets.ts b/web/models/datasets.ts index 0167babf74..fa52ae48cb 100644 --- a/web/models/datasets.ts +++ b/web/models/datasets.ts @@ -22,6 +22,9 @@ export type DataSet = { app_count: number document_count: number word_count: number + embedding_model: string + embedding_model_provider: string + embedding_available: boolean } export type CustomFile = File & { @@ -184,6 +187,7 @@ export type CreateDocumentReq = { original_document_id?: string indexing_technique?: string doc_form: 'text_model' | 'qa_model' + doc_language: string data_source: DataSource process_rule: ProcessRule } @@ -390,3 +394,8 @@ export type SegmentUpdator = { answer?: string keywords?: string[] } + +export enum DocForm { + TEXT = 'text_model', + QA = 'qa_model', +} diff --git a/web/service/datasets.ts b/web/service/datasets.ts index 209260e93f..d2178db0d5 100644 --- a/web/service/datasets.ts +++ b/web/service/datasets.ts @@ -118,6 +118,10 @@ export const archiveDocument: Fetcher = ({ dataset return patch(`/datasets/${datasetId}/documents/${documentId}/status/archive`) as Promise } +export const unArchiveDocument: Fetcher = ({ datasetId, documentId }) => { + return patch(`/datasets/${datasetId}/documents/${documentId}/status/un_archive`) as Promise +} + export const enableDocument: Fetcher = ({ datasetId, documentId }) => { return patch(`/datasets/${datasetId}/documents/${documentId}/status/enable`) as Promise } @@ -138,10 
+142,6 @@ export const modifyDocMetadata: Fetcher } -export const getDatasetIndexingStatus: Fetcher<{ data: IndexingStatusResponse[] }, string> = (datasetId) => { - return get(`/datasets/${datasetId}/indexing-status`) as Promise<{ data: IndexingStatusResponse[] }> -} - // apis for segments in a document export const fetchSegments: Fetcher = ({ datasetId, documentId, params }) => { @@ -164,6 +164,18 @@ export const addSegment: Fetcher<{ data: SegmentDetailModel; doc_form: string }, return post(`/datasets/${datasetId}/documents/${documentId}/segment`, { body }) as Promise<{ data: SegmentDetailModel; doc_form: string }> } +export const deleteSegment: Fetcher = ({ datasetId, documentId, segmentId }) => { + return del(`/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}`) as Promise +} + +export const segmentBatchImport: Fetcher<{ job_id: string; job_status: string }, { url: string; body: FormData }> = ({ url, body }) => { + return post(url, { body }, { bodyStringify: false, deleteContentType: true }) as Promise<{ job_id: string; job_status: string }> +} + +export const checkSegmentBatchImportProgress: Fetcher<{ job_id: string; job_status: string }, { jobID: string }> = ({ jobID }) => { + return get(`/datasets/batch_import_status/${jobID}`) as Promise<{ job_id: string; job_status: string }> +} + // hit testing export const hitTesting: Fetcher = ({ datasetId, queryText }) => { return post(`/datasets/${datasetId}/hit-testing`, { body: { query: queryText } }) as Promise