Feat/support parent child chunk (#12092)

Jyong 2024-12-25 19:49:07 +08:00 committed by GitHub
parent 017d7538ae
commit 9231fdbf4c
54 changed files with 2578 additions and 808 deletions

View File

@@ -218,7 +218,7 @@ class DataSourceNotionApi(Resource):
             args["doc_form"],
             args["doc_language"],
         )
-        return response, 200
+        return response.model_dump(), 200


 class DataSourceNotionDatasetSyncApi(Resource):

View File

@@ -464,7 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
            except Exception as e:
                raise IndexingEstimateError(str(e))
-        return response, 200
+        return response.model_dump(), 200


 class DatasetRelatedAppListApi(Resource):
@@ -733,6 +733,18 @@ class DatasetPermissionUserListApi(Resource):
         }, 200


+class DatasetAutoDisableLogApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
+
+
 api.add_resource(DatasetListApi, "/datasets")
 api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
 api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
@@ -747,3 +759,4 @@ api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
 api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
 api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
 api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
+api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")
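
For quick reference, a minimal sketch of calling the new auto-disable-logs route from a client. The base URL, console prefix, token, and dataset ID are placeholder assumptions, not part of this commit:

    import requests

    # Hypothetical console client call against the route registered above.
    BASE = "http://localhost:5001/console/api"  # assumed local deployment
    dataset_id = "<dataset-uuid>"

    resp = requests.get(
        f"{BASE}/datasets/{dataset_id}/auto-disable-logs",
        headers={"Authorization": "Bearer <console-access-token>"},
    )
    resp.raise_for_status()
    print(resp.json())  # whatever DatasetService.get_dataset_auto_disable_logs returns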

View File

@@ -52,6 +52,7 @@ from fields.document_fields import (
 from libs.login import login_required
 from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
 from services.dataset_service import DatasetService, DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from tasks.add_document_to_index_task import add_document_to_index_task
 from tasks.remove_document_from_index_task import remove_document_from_index_task
@@ -255,20 +256,22 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
         parser.add_argument("original_document_id", type=str, required=False, location="json")
         parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
         parser.add_argument(
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
+        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
         args = parser.parse_args()
+        knowledge_config = KnowledgeConfig(**args)

-        if not dataset.indexing_technique and not args["indexing_technique"]:
+        if not dataset.indexing_technique and not knowledge_config.indexing_technique:
             raise ValueError("indexing_technique is required.")

         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
-            documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+            documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
@@ -278,6 +281,25 @@ class DatasetDocumentListApi(Resource):
         return {"documents": documents, "batch": batch}

+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, dataset_id):
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
+        try:
+            document_ids = request.args.getlist("document_id")
+            DocumentService.delete_documents(dataset, document_ids)
+        except services.errors.document.DocumentIndexingError:
+            raise DocumentIndexingError("Cannot delete document during indexing.")
+
+        return {"result": "success"}, 204
+

 class DatasetInitApi(Resource):
     @setup_required
@@ -313,9 +335,9 @@ class DatasetInitApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
         if not current_user.is_dataset_editor:
             raise Forbidden()
-        if args["indexing_technique"] == "high_quality":
-            if args["embedding_model"] is None or args["embedding_model_provider"] is None:
+        knowledge_config = KnowledgeConfig(**args)
+        if knowledge_config.indexing_technique == "high_quality":
+            if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
                 raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
             try:
                 model_manager = ModelManager()
@@ -334,11 +356,11 @@ class DatasetInitApi(Resource):
                 raise ProviderNotInitializeError(ex.description)

         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             dataset, documents, batch = DocumentService.save_document_without_dataset_id(
-                tenant_id=current_user.current_tenant_id, document_data=args, account=current_user
+                tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -409,7 +431,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 except Exception as e:
                     raise IndexingEstimateError(str(e))

-        return response
+        return response.model_dump(), 200


 class DocumentBatchIndexingEstimateApi(DocumentResource):
@@ -422,7 +444,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         documents = self.get_batch_documents(dataset_id, batch)
         response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
         if not documents:
-            return response
+            return response, 200
         data_process_rule = documents[0].dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
         info_list = []
@@ -509,7 +531,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
            raise ProviderNotInitializeError(ex.description)
        except Exception as e:
            raise IndexingEstimateError(str(e))
-        return response
+        return response.model_dump(), 200


 class DocumentBatchIndexingStatusApi(DocumentResource):
@@ -582,7 +604,8 @@ class DocumentDetailApi(DocumentResource):
         if metadata == "only":
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
         elif metadata == "without":
-            process_rules = DatasetService.get_process_rules(dataset_id)
+            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
+            document_process_rules = document.dataset_process_rule.to_dict()
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -590,7 +613,8 @@
                 "data_source_type": document.data_source_type,
                 "data_source_info": data_source_info,
                 "dataset_process_rule_id": document.dataset_process_rule_id,
-                "dataset_process_rule": process_rules,
+                "dataset_process_rule": dataset_process_rules,
+                "document_process_rule": document_process_rules,
                 "name": document.name,
                 "created_from": document.created_from,
                 "created_by": document.created_by,
@@ -613,7 +637,8 @@
                 "doc_language": document.doc_language,
             }
         else:
-            process_rules = DatasetService.get_process_rules(dataset_id)
+            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
+            document_process_rules = document.dataset_process_rule.to_dict()
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -621,7 +646,8 @@
                 "data_source_type": document.data_source_type,
                 "data_source_info": data_source_info,
                 "dataset_process_rule_id": document.dataset_process_rule_id,
-                "dataset_process_rule": process_rules,
+                "dataset_process_rule": dataset_process_rules,
+                "document_process_rule": document_process_rules,
                 "name": document.name,
                 "created_from": document.created_from,
                 "created_by": document.created_by,
@@ -757,9 +783,8 @@ class DocumentStatusApi(DocumentResource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
-    def patch(self, dataset_id, document_id, action):
+    def patch(self, dataset_id, action):
         dataset_id = str(dataset_id)
-        document_id = str(document_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:
             raise NotFound("Dataset not found.")
@@ -774,84 +799,79 @@ class DocumentStatusApi(DocumentResource):
         # check user's permission
         DatasetService.check_dataset_permission(dataset, current_user)

-        document = self.get_document(dataset_id, document_id)
+        document_ids = request.args.getlist("document_id")
+        for document_id in document_ids:
+            document = self.get_document(dataset_id, document_id)
             indexing_cache_key = "document_{}_indexing".format(document.id)
             cache_result = redis_client.get(indexing_cache_key)
             if cache_result is not None:
-                raise InvalidActionError("Document is being indexed, please try again later")
+                raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")

             if action == "enable":
                 if document.enabled:
-                    raise InvalidActionError("Document already enabled.")
+                    continue
                 document.enabled = True
                 document.disabled_at = None
                 document.disabled_by = None
                 document.updated_at = datetime.now(UTC).replace(tzinfo=None)
                 db.session.commit()
                 # Set cache to prevent indexing the same document multiple times
                 redis_client.setex(indexing_cache_key, 600, 1)
                 add_document_to_index_task.delay(document_id)
-                return {"result": "success"}, 200

             elif action == "disable":
                 if not document.completed_at or document.indexing_status != "completed":
-                    raise InvalidActionError("Document is not completed.")
+                    raise InvalidActionError(f"Document: {document.name} is not completed.")
                 if not document.enabled:
-                    raise InvalidActionError("Document already disabled.")
+                    continue
                 document.enabled = False
                 document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
                 document.disabled_by = current_user.id
                 document.updated_at = datetime.now(UTC).replace(tzinfo=None)
                 db.session.commit()
                 # Set cache to prevent indexing the same document multiple times
                 redis_client.setex(indexing_cache_key, 600, 1)
                 remove_document_from_index_task.delay(document_id)
-                return {"result": "success"}, 200

             elif action == "archive":
                 if document.archived:
-                    raise InvalidActionError("Document already archived.")
+                    continue
                 document.archived = True
                 document.archived_at = datetime.now(UTC).replace(tzinfo=None)
                 document.archived_by = current_user.id
                 document.updated_at = datetime.now(UTC).replace(tzinfo=None)
                 db.session.commit()
                 if document.enabled:
                     # Set cache to prevent indexing the same document multiple times
                     redis_client.setex(indexing_cache_key, 600, 1)
                     remove_document_from_index_task.delay(document_id)
-                return {"result": "success"}, 200

             elif action == "un_archive":
                 if not document.archived:
-                    raise InvalidActionError("Document is not archived.")
+                    continue
                 document.archived = False
                 document.archived_at = None
                 document.archived_by = None
                 document.updated_at = datetime.now(UTC).replace(tzinfo=None)
                 db.session.commit()
                 # Set cache to prevent indexing the same document multiple times
                 redis_client.setex(indexing_cache_key, 600, 1)
                 add_document_to_index_task.delay(document_id)
-                return {"result": "success"}, 200
             else:
                 raise InvalidActionError()
+        return {"result": "success"}, 200


 class DocumentPauseApi(DocumentResource):
@@ -1022,7 +1042,7 @@ api.add_resource(
 )
 api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
 api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
-api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>")
+api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
 api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
 api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
 api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
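
Taken together, the document routes above replace the per-document status endpoint with a batch endpoint and add batch deletion that reads repeated document_id query parameters. A rough usage sketch, assuming DatasetDocumentListApi stays mounted at /datasets/<uuid:dataset_id>/documents and that host, IDs, and auth are placeholders:

    import requests

    BASE = "http://localhost:5001/console/api"  # assumed
    dataset_id = "<dataset-uuid>"
    doc_ids = ["<doc-uuid-1>", "<doc-uuid-2>"]
    headers = {"Authorization": "Bearer <console-access-token>"}

    # Enable several documents in one call via the new batch status route.
    requests.patch(
        f"{BASE}/datasets/{dataset_id}/documents/status/enable/batch",
        params=[("document_id", d) for d in doc_ids],
        headers=headers,
    )

    # Delete several documents in one call; the handler above returns 204 on success.
    requests.delete(
        f"{BASE}/datasets/{dataset_id}/documents",
        params=[("document_id", d) for d in doc_ids],
        headers=headers,
    )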

View File

@@ -1,5 +1,4 @@
 import uuid
-from datetime import UTC, datetime

 import pandas as pd
 from flask import request
@@ -10,7 +9,13 @@ from werkzeug.exceptions import Forbidden, NotFound
 import services
 from controllers.console import api
 from controllers.console.app.error import ProviderNotInitializeError
-from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
+from controllers.console.datasets.error import (
+    ChildChunkDeleteIndexError,
+    ChildChunkIndexingError,
+    InvalidActionError,
+    NoFileUploadedError,
+    TooManyFilesError,
+)
 from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_knowledge_limit_check,
@@ -20,15 +25,15 @@ from controllers.console.wraps import (
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from fields.segment_fields import segment_fields
+from fields.segment_fields import child_chunk_fields, segment_fields
 from libs.login import login_required
-from models import DocumentSegment
+from models.dataset import ChildChunk, DocumentSegment
 from services.dataset_service import DatasetService, DocumentService, SegmentService
+from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
+from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
+from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
 from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
-from tasks.disable_segment_from_index_task import disable_segment_from_index_task
-from tasks.enable_segment_to_index_task import enable_segment_to_index_task
 class DatasetDocumentSegmentListApi(Resource):
@@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource):
             raise NotFound("Document not found.")

         parser = reqparse.RequestParser()
-        parser.add_argument("last_id", type=str, default=None, location="args")
         parser.add_argument("limit", type=int, default=20, location="args")
         parser.add_argument("status", type=str, action="append", default=[], location="args")
         parser.add_argument("hit_count_gte", type=int, default=None, location="args")
         parser.add_argument("enabled", type=str, default="all", location="args")
         parser.add_argument("keyword", type=str, default=None, location="args")
+        parser.add_argument("page", type=int, default=1, location="args")
         args = parser.parse_args()

-        last_id = args["last_id"]
+        page = args["page"]
         limit = min(args["limit"], 100)
         status_list = args["status"]
         hit_count_gte = args["hit_count_gte"]
@@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
         query = DocumentSegment.query.filter(
             DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        )
-        if last_id is not None:
-            last_segment = db.session.get(DocumentSegment, str(last_id))
-            if last_segment:
-                query = query.filter(DocumentSegment.position > last_segment.position)
-            else:
-                return {"data": [], "has_more": False, "limit": limit}, 200
+        ).order_by(DocumentSegment.position.asc())

         if status_list:
             query = query.filter(DocumentSegment.status.in_(status_list))
@@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource):
         elif args["enabled"].lower() == "false":
             query = query.filter(DocumentSegment.enabled == False)

-        total = query.count()
-        segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
-        has_more = False
-        if len(segments) > limit:
-            has_more = True
-            segments = segments[:-1]
-        return {
-            "data": marshal(segments, segment_fields),
-            "doc_form": document.doc_form,
-            "has_more": has_more,
-            "limit": limit,
-            "total": total,
-        }, 200
+        segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+
+        response = {
+            "data": marshal(segments.items, segment_fields),
+            "limit": limit,
+            "total": segments.total,
+            "total_pages": segments.pages,
+            "page": page,
+        }
+        return response, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_ids = request.args.getlist("segment_id")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
return {"result": "success"}, 200
 class DatasetDocumentSegmentApi(Resource):
@@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
-    def patch(self, dataset_id, segment_id, action):
+    def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound("Dataset not found.")
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin, owner, or editor
@@ -147,59 +173,17 @@ class DatasetDocumentSegmentApi(Resource):
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
-        if not segment:
-            raise NotFound("Segment not found.")
-        if segment.status != "completed":
-            raise NotFound("Segment is not completed, enable or disable function is not allowed")
-        document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
+        segment_ids = request.args.getlist("segment_id")
+
+        document_indexing_cache_key = "document_{}_indexing".format(document.id)
         cache_result = redis_client.get(document_indexing_cache_key)
         if cache_result is not None:
             raise InvalidActionError("Document is being indexed, please try again later")
-        indexing_cache_key = "segment_{}_indexing".format(segment.id)
-        cache_result = redis_client.get(indexing_cache_key)
-        if cache_result is not None:
-            raise InvalidActionError("Segment is being indexed, please try again later")
-        if action == "enable":
-            if segment.enabled:
-                raise InvalidActionError("Segment is already enabled.")
-            segment.enabled = True
-            segment.disabled_at = None
-            segment.disabled_by = None
-            db.session.commit()
-            # Set cache to prevent indexing the same segment multiple times
-            redis_client.setex(indexing_cache_key, 600, 1)
-            enable_segment_to_index_task.delay(segment.id)
-            return {"result": "success"}, 200
-        elif action == "disable":
-            if not segment.enabled:
-                raise InvalidActionError("Segment is already disabled.")
-            segment.enabled = False
-            segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
-            segment.disabled_by = current_user.id
-            db.session.commit()
-            # Set cache to prevent indexing the same segment multiple times
-            redis_client.setex(indexing_cache_key, 600, 1)
-            disable_segment_from_index_task.delay(segment.id)
-            return {"result": "success"}, 200
-        else:
-            raise InvalidActionError()
+        try:
+            SegmentService.update_segments_status(segment_ids, action, dataset, document)
+        except Exception as e:
+            raise InvalidActionError(str(e))
+        return {"result": "success"}, 200
 class DatasetDocumentSegmentAddApi(Resource):
@@ -307,9 +291,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         parser.add_argument("content", type=str, required=True, nullable=False, location="json")
         parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
         parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
+        parser.add_argument(
+            "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
+        )
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.update_segment(args, segment, document, dataset)
+        segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
         return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

     @setup_required
@@ -412,8 +399,248 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         return {"job_id": job_id, "job_status": cache_result.decode()}, 200
class ChildChunkAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
try:
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.update_child_chunk(
args.get("content"), child_chunk, segment, document, dataset
)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
 api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
-api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
+api.add_resource(
+    DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
+)
 api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
 api.add_resource(
     DatasetDocumentSegmentUpdateApi,
@@ -424,3 +651,11 @@ api.add_resource(
     "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
     "/datasets/batch_import_status/<uuid:job_id>",
 )
+api.add_resource(
+    ChildChunkAddApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
+)
+api.add_resource(
+    ChildChunkUpdateApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
+)
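
A short sketch of how the new child-chunk routes could be exercised, following the registrations above. Host, IDs, and the auth header are placeholder assumptions, and the exact fields accepted by ChildChunkUpdateArgs are not shown in this diff:

    import requests

    BASE = "http://localhost:5001/console/api"  # assumed
    headers = {"Authorization": "Bearer <console-access-token>"}
    seg = f"{BASE}/datasets/<dataset-uuid>/documents/<doc-uuid>/segments/<segment-uuid>"

    # Add one child chunk under a parent segment.
    requests.post(f"{seg}/child_chunks", json={"content": "a new child chunk"}, headers=headers)

    # Page through existing child chunks.
    requests.get(f"{seg}/child_chunks", params={"page": 1, "limit": 20}, headers=headers)

    # Replace the segment's child chunks in bulk (PATCH on the collection).
    requests.patch(f"{seg}/child_chunks", json={"chunks": [{"content": "rewritten child chunk"}]}, headers=headers)

    # Edit or delete a single child chunk.
    requests.patch(f"{seg}/child_chunks/<child-chunk-uuid>", json={"content": "edited"}, headers=headers)
    requests.delete(f"{seg}/child_chunks/<child-chunk-uuid>", headers=headers)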

View File

@@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException):
     error_code = "indexing_estimate_error"
     description = "Knowledge indexing estimate failed: {message}"
     code = 500
+
+
+class ChildChunkIndexingError(BaseHTTPException):
+    error_code = "child_chunk_indexing_error"
+    description = "Create child chunk index failed: {message}"
+    code = 500
+
+
+class ChildChunkDeleteIndexError(BaseHTTPException):
+    error_code = "child_chunk_delete_index_error"
+    description = "Delete child chunk index failed: {message}"
+    code = 500

View File

@@ -16,6 +16,7 @@ from extensions.ext_database import db
 from fields.segment_fields import segment_fields
 from models.dataset import Dataset, DocumentSegment
 from services.dataset_service import DatasetService, DocumentService, SegmentService
+from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs


 class SegmentApi(DatasetApiResource):
@@ -193,7 +194,7 @@ class DatasetSegmentApi(DatasetApiResource):
         args = parser.parse_args()

         SegmentService.segment_create_args_validate(args["segment"], document)
-        segment = SegmentService.update_segment(args["segment"], segment, document, dataset)
+        segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset)
         return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
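
update_segment now receives a typed SegmentUpdateArgs entity instead of a raw dict. The real definition lives in services/entities/knowledge_entities/knowledge_entities.py and is not shown here; the sketch below is an assumed shape that simply mirrors the parser arguments used by the console controller:

    from typing import Optional

    from pydantic import BaseModel


    class SegmentUpdateArgs(BaseModel):  # assumed field set, for illustration only
        content: Optional[str] = None
        answer: Optional[str] = None
        keywords: Optional[list[str]] = None
        regenerate_child_chunks: bool = False


    args = {"segment": {"content": "updated text", "keywords": ["alpha"]}}
    update_args = SegmentUpdateArgs(**args["segment"])  # same construction as the call above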

View File

@@ -0,0 +1,19 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class PreviewDetail(BaseModel):
+    content: str
+    child_chunks: Optional[list[str]] = None
+
+
+class QAPreviewDetail(BaseModel):
+    question: str
+    answer: str
+
+
+class IndexingEstimate(BaseModel):
+    total_segments: int
+    preview: list[PreviewDetail]
+    qa_preview: Optional[list[QAPreviewDetail]] = None
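
These pydantic models replace the ad-hoc dicts the estimate path used to return, which is why the controllers above now call response.model_dump() before handing the result to Flask-RESTful. A small usage sketch:

    from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail

    estimate = IndexingEstimate(
        total_segments=2,
        preview=[
            PreviewDetail(content="Parent chunk text", child_chunks=["child one", "child two"]),
            PreviewDetail(content="Another parent chunk"),
        ],
        qa_preview=[QAPreviewDetail(question="What is X?", answer="X is ...")],
    )
    payload = estimate.model_dump()  # plain dict: {"total_segments": 2, "preview": [...], "qa_preview": [...]}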

View File

@@ -8,34 +8,34 @@ import time
 import uuid
 from typing import Any, Optional, cast

-from flask import Flask, current_app
+from flask import current_app
 from flask_login import current_user  # type: ignore
 from sqlalchemy.orm.exc import ObjectDeletedError

 from configs import dify_config
+from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
 from core.errors.error import ProviderTokenNotInitError
-from core.llm_generator.llm_generator import LLMGenerator
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import Document
+from core.rag.models.document import ChildDocument, Document
 from core.rag.splitter.fixed_text_splitter import (
     EnhanceRecursiveCharacterTextSplitter,
     FixedRecursiveCharacterTextSplitter,
 )
 from core.rag.splitter.text_splitter import TextSplitter
-from core.tools.utils.text_processing_utils import remove_leading_symbols
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from libs import helper
-from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import UploadFile
 from services.feature_service import FeatureService
@@ -115,6 +115,9 @@ class IndexingRunner:
             for document_segment in document_segments:
                 db.session.delete(document_segment)
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # delete child chunks
+                    db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
             db.session.commit()
         # get the process rule
         processing_rule = (
@@ -183,7 +186,22 @@ class IndexingRunner:
                     "dataset_id": document_segment.dataset_id,
                 },
             )
+            if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                child_chunks = document_segment.child_chunks
+                if child_chunks:
+                    child_documents = []
+                    for child_chunk in child_chunks:
+                        child_document = ChildDocument(
+                            page_content=child_chunk.content,
+                            metadata={
+                                "doc_id": child_chunk.index_node_id,
+                                "doc_hash": child_chunk.index_node_hash,
+                                "document_id": document_segment.document_id,
+                                "dataset_id": document_segment.dataset_id,
+                            },
+                        )
+                        child_documents.append(child_document)
+                    document.children = child_documents

             documents.append(document)

         # build index
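
The net effect of this hunk is that each parent Document carries its child chunks as ChildDocument objects on document.children before the index is built. A hand-built illustration of the resulting structure, assuming Document exposes an optional children field as the assignment above implies:

    from core.rag.models.document import ChildDocument, Document

    parent = Document(
        page_content="Parent chunk text",
        metadata={"doc_id": "node-1", "doc_hash": "hash-1", "document_id": "doc-1", "dataset_id": "ds-1"},
    )
    parent.children = [
        ChildDocument(
            page_content="First child chunk",
            metadata={"doc_id": "node-1-a", "doc_hash": "hash-1a", "document_id": "doc-1", "dataset_id": "ds-1"},
        )
    ]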
@@ -222,7 +240,7 @@
         doc_language: str = "English",
         dataset_id: Optional[str] = None,
         indexing_technique: str = "economy",
-    ) -> dict:
+    ) -> IndexingEstimate:
         """
         Estimate the indexing for the document.
         """
@@ -258,31 +276,38 @@
                 tenant_id=tenant_id,
                 model_type=ModelType.TEXT_EMBEDDING,
             )
-        preview_texts: list[str] = []
+        preview_texts = []
         total_segments = 0
         index_type = doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        all_text_docs = []
         for extract_setting in extract_settings:
             # extract
-            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
-            all_text_docs.extend(text_docs)
             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
             )
-            # get splitter
-            splitter = self._get_splitter(processing_rule, embedding_model_instance)
-            # split to documents
-            documents = self._split_to_documents_for_estimate(
-                text_docs=text_docs, splitter=splitter, processing_rule=processing_rule
-            )
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
+            documents = index_processor.transform(
+                text_docs,
+                embedding_model_instance=embedding_model_instance,
+                process_rule=processing_rule.to_dict(),
+                tenant_id=current_user.current_tenant_id,
+                doc_language=doc_language,
+                preview=True,
+            )
             total_segments += len(documents)
             for document in documents:
-                if len(preview_texts) < 5:
-                    preview_texts.append(document.page_content)
+                if len(preview_texts) < 10:
+                    if doc_form and doc_form == "qa_model":
+                        preview_detail = QAPreviewDetail(
+                            question=document.page_content, answer=document.metadata.get("answer")
+                        )
+                        preview_texts.append(preview_detail)
+                    else:
+                        preview_detail = PreviewDetail(content=document.page_content)
+                        if document.children:
+                            preview_detail.child_chunks = [child.page_content for child in document.children]
+                        preview_texts.append(preview_detail)

                 # delete image files and related db records
                 image_upload_file_ids = get_image_upload_file_ids(document.page_content)
@@ -299,15 +324,8 @@
                     db.session.delete(image_file)

         if doc_form and doc_form == "qa_model":
-            if len(preview_texts) > 0:
-                # qa model document
-                response = LLMGenerator.generate_qa_document(
-                    current_user.current_tenant_id, preview_texts[0], doc_language
-                )
-                document_qa_list = self.format_split_text(response)
-                return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts}
-        return {"total_segments": total_segments, "preview": preview_texts}
+            return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
+        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
     def _extract(
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -401,31 +419,26 @@
     @staticmethod
     def _get_splitter(
-        processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]
+        processing_rule_mode: str,
+        max_tokens: int,
+        chunk_overlap: int,
+        separator: str,
+        embedding_model_instance: Optional[ModelInstance],
     ) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
-        character_splitter: TextSplitter
-        if processing_rule.mode == "custom":
+        if processing_rule_mode in ["custom", "hierarchical"]:
             # The user-defined segmentation rule
-            rules = json.loads(processing_rule.rules)
-            segmentation = rules["segmentation"]
             max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
-            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
+            if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
                 raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
-            separator = segmentation["separator"]
             if separator:
                 separator = separator.replace("\\n", "\n")
-            if segmentation.get("chunk_overlap"):
-                chunk_overlap = segmentation["chunk_overlap"]
-            else:
-                chunk_overlap = 0
             character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
-                chunk_size=segmentation["max_tokens"],
+                chunk_size=max_tokens,
                 chunk_overlap=chunk_overlap,
                 fixed_separator=separator,
                 separators=["\n\n", "。", ". ", " ", ""],
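
_get_splitter no longer reads the DatasetProcessRule itself; callers are expected to unpack the rule and pass the segmentation values directly, with "hierarchical" handled like "custom". A hedged sketch of a call site under that assumption (the rule dict below is hypothetical):

    from core.indexing_runner import IndexingRunner

    rule = {"mode": "custom", "rules": {"segmentation": {"max_tokens": 500, "chunk_overlap": 50, "separator": "\\n"}}}
    segmentation = rule["rules"]["segmentation"]

    splitter = IndexingRunner._get_splitter(
        processing_rule_mode=rule["mode"],
        max_tokens=segmentation["max_tokens"],
        chunk_overlap=segmentation.get("chunk_overlap", 0),
        separator=segmentation["separator"],
        embedding_model_instance=None,  # economy indexing; pass a ModelInstance for high quality
    )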
@@ -443,142 +456,6 @@ class IndexingRunner:
         return character_splitter
def _step_split(
self,
text_docs: list[Document],
splitter: TextSplitter,
dataset: Dataset,
dataset_document: DatasetDocument,
processing_rule: DatasetProcessRule,
) -> list[Document]:
"""
Split the text documents into documents and save them to the document segment.
"""
documents = self._split_to_documents(
text_docs=text_docs,
splitter=splitter,
processing_rule=processing_rule,
tenant_id=dataset.tenant_id,
document_form=dataset_document.doc_form,
document_language=dataset_document.doc_language,
)
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(documents)
# update document status to indexing
cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
},
)
# update segment status to indexing
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
},
)
return documents
def _split_to_documents(
self,
text_docs: list[Document],
splitter: TextSplitter,
processing_rule: DatasetProcessRule,
tenant_id: str,
document_form: str,
document_language: str,
) -> list[Document]:
"""
Split the text documents into nodes.
"""
all_documents: list[Document] = []
all_qa_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.page_content = document_text
# parse document to nodes
documents = splitter.split_documents([text_doc])
split_documents = []
for document_node in documents:
if document_node.page_content.strip():
if document_node.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)
if document_node.page_content:
split_documents.append(document_node)
all_documents.extend(split_documents)
# processing qa document
if document_form == "qa_model":
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i : i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(
target=self.format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": tenant_id,
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": document_language,
},
)
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
return all_qa_documents
return all_documents
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
with flask_app.app_context():
try:
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(
page_content=result["question"], metadata=document_node.metadata.model_copy()
)
if qa_document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception("Failed to format qa document")
all_qa_documents.extend(format_documents)
def _split_to_documents_for_estimate( def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
) -> list[Document]: ) -> list[Document]:
@ -624,11 +501,11 @@ class IndexingRunner:
return document_text return document_text
@staticmethod @staticmethod
def format_split_text(text): def format_split_text(text: str) -> list[QAPreviewDetail]:
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE) matches = re.findall(regex, text, re.UNICODE)
return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a]
def _load( def _load(
self, self,
@ -654,13 +531,14 @@ class IndexingRunner:
indexing_start_at = time.perf_counter() indexing_start_at = time.perf_counter()
tokens = 0 tokens = 0
chunk_size = 10 chunk_size = 10
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
)
create_keyword_thread.start()
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [] futures = []
@ -680,8 +558,8 @@ class IndexingRunner:
for future in futures: for future in futures:
tokens += future.result() tokens += future.result()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
create_keyword_thread.join() create_keyword_thread.join()
indexing_end_at = time.perf_counter() indexing_end_at = time.perf_counter()
# update document status to completed # update document status to completed
@ -793,28 +671,6 @@ class IndexingRunner:
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit() db.session.commit()
@staticmethod
def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
"""
Batch add segments index processing
"""
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
# save vector index
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
def _transform( def _transform(
self, self,
index_processor: BaseIndexProcessor, index_processor: BaseIndexProcessor,
@ -856,7 +712,7 @@ class IndexingRunner:
) )
# add document segments # add document segments
doc_store.add_documents(documents) doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
# update document status to indexing # update document status to indexing
cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

View File

@ -6,11 +6,14 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = { default_retrieval_model = {
@ -248,3 +251,88 @@ class RetrievalService:
@staticmethod @staticmethod
def escape_query_for_search(query: str) -> str: def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"') return query.replace('"', '\\"')
@staticmethod
def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
records = []
include_segment_ids = []
segment_child_map = {}
for document in documents:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document:
    continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata["doc_id"]
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
)
.first()
)
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
continue
else:
index_node_id = document.metadata["doc_id"]
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
if not segment:
continue
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
records.append(record)
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
record["score"] = segment_child_map[record["segment"].id]["max_score"]
return [RetrievalSegments(**record) for record in records]
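A minimal sketch of how a caller might consume the grouped result; it assumes a Flask app/database context and placeholder dataset_id/query values, none of which are prescribed by this commit:

from core.rag.datasource.retrieval_service import RetrievalService

# retrieve child-chunk hits, then collapse them onto their parent segments
documents = RetrievalService.retrieve(
    retrieval_method="semantic_search", dataset_id=dataset_id, query=query, top_k=4
)
for record in RetrievalService.format_retrieval_documents(documents):
    print(record.segment.id, record.score)
    for child in record.child_chunks or []:
        print("  child chunk", child.position, round(child.score, 3))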

View File

@ -7,7 +7,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import ChildChunk, Dataset, DocumentSegment
class DatasetDocumentStore: class DatasetDocumentStore:
@ -60,7 +60,7 @@ class DatasetDocumentStore:
return output return output
def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = ( max_position = (
db.session.query(func.max(DocumentSegment.position)) db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id) .filter(DocumentSegment.document_id == self._document_id)
@ -120,6 +120,23 @@ class DatasetDocumentStore:
segment_document.answer = doc.metadata.pop("answer", "") segment_document.answer = doc.metadata.pop("answer", "")
db.session.add(segment_document) db.session.add(segment_document)
db.session.flush()
if save_child and doc.children:
    for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
else: else:
segment_document.content = doc.page_content segment_document.content = doc.page_content
if doc.metadata.get("answer"): if doc.metadata.get("answer"):
@ -127,6 +144,30 @@ class DatasetDocumentStore:
segment_document.index_node_hash = doc.metadata["doc_hash"] segment_document.index_node_hash = doc.metadata["doc_hash"]
segment_document.word_count = len(doc.page_content) segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens segment_document.tokens = tokens
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).filter(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
ChildChunk.segment_id == segment_document.id,
).delete()
# add new child chunks
for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
db.session.commit() db.session.commit()

View File

@ -0,0 +1,23 @@
from typing import Optional
from pydantic import BaseModel
from models.dataset import DocumentSegment
class RetrievalChildChunk(BaseModel):
"""Retrieval segments."""
id: str
content: str
score: float
position: int
class RetrievalSegments(BaseModel):
"""Retrieval segments."""
model_config = {"arbitrary_types_allowed": True}
segment: DocumentSegment
child_chunks: Optional[list[RetrievalChildChunk]] = None
score: Optional[float] = None

View File

@ -24,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_markdown_extractor import Unst
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
from core.rag.extractor.word_extractor import WordExtractor from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document from core.rag.models.document import Document
@ -141,11 +140,7 @@ class ExtractProcessor:
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key)
else: else:
# txt # txt
extractor = ( extractor = TextExtractor(file_path, autodetect_encoding=True)
UnstructuredTextExtractor(file_path, unstructured_api_url)
if is_automatic
else TextExtractor(file_path, autodetect_encoding=True)
)
else: else:
if file_extension in {".xlsx", ".xls"}: if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path) extractor = ExcelExtractor(file_path)

View File

@ -267,8 +267,10 @@ class WordExtractor(BaseExtractor):
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
para = paragraphs.pop(0) para = paragraphs.pop(0)
parsed_paragraph = parse_paragraph(para) parsed_paragraph = parse_paragraph(para)
if parsed_paragraph: if parsed_paragraph.strip():
content.append(parsed_paragraph) content.append(parsed_paragraph)
else:
content.append("\n")
elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table
table = tables.pop(0) table = tables.pop(0)
content.append(self._table_to_markdown(table, image_map)) content.append(self._table_to_markdown(table, image_map))

View File

@ -1,8 +1,7 @@
from enum import Enum from enum import Enum
class IndexType(Enum): class IndexType(str, Enum):
PARAGRAPH_INDEX = "text_model" PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model" QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "parent_child_index" PARENT_CHILD_INDEX = "hierarchical_model"
SUMMARY_INDEX = "summary_index"
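Because IndexType now subclasses str, the persisted doc_form string compares directly against the enum member, which is why the comparisons in the factory and indexing runner below drop .value (a small illustrative check):

from core.rag.index_processor.constant.index_type import IndexType

doc_form = "hierarchical_model"  # value stored on Document.doc_form
assert doc_form == IndexType.PARENT_CHILD_INDEX
assert IndexType.PARENT_CHILD_INDEX == "hierarchical_model"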

View File

@ -27,10 +27,10 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
raise NotImplementedError raise NotImplementedError
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@ -45,26 +45,29 @@ class BaseIndexProcessor(ABC):
) -> list[Document]: ) -> list[Document]:
raise NotImplementedError raise NotImplementedError
def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: def _get_splitter(
self,
processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
""" """
Get the NodeParser object according to the processing rule. Get the NodeParser object according to the processing rule.
""" """
character_splitter: TextSplitter if processing_rule_mode in ["custom", "hierarchical"]:
if processing_rule["mode"] == "custom":
# The user-defined segmentation rule # The user-defined segmentation rule
rules = processing_rule["rules"]
segmentation = rules["segmentation"]
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
separator = segmentation["separator"]
if separator: if separator:
separator = separator.replace("\\n", "\n") separator = separator.replace("\\n", "\n")
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"], chunk_size=max_tokens,
chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, chunk_overlap=chunk_overlap,
fixed_separator=separator, fixed_separator=separator,
separators=["\n\n", "", ". ", " ", ""], separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,

View File

@ -3,6 +3,7 @@
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
@ -18,9 +19,11 @@ class IndexProcessorFactory:
if not self._index_type: if not self._index_type:
raise ValueError("Index type must be specified.") raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX.value: if self._index_type == IndexType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor() return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX.value: elif self._index_type == IndexType.QA_INDEX:
return QAIndexProcessor() return QAIndexProcessor()
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else: else:
raise ValueError(f"Index type {self._index_type} is not supported.") raise ValueError(f"Index type {self._index_type} is not supported.")
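A short sketch of how the factory is expected to pick the new processor for a parent-child dataset (the doc_form string follows the IndexType change above):

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

# "hierarchical_model" is the persisted doc_form for parent-child datasets
index_processor = IndexProcessorFactory("hierarchical_model").init_index_processor()
# index_processor is a ParentChildIndexProcessor instance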

View File

@ -13,21 +13,34 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper from libs import helper
from models.dataset import Dataset from models.dataset import Dataset, DatasetProcessRule
from services.entities.knowledge_entities.knowledge_entities import Rule
class ParagraphIndexProcessor(BaseIndexProcessor): class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract( text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
) )
return text_docs return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if process_rule.get("mode") == "automatic":
automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
rules = Rule(**automatic_rule)
else:
rules = Rule(**process_rule.get("rules"))
# Split the text documents into nodes. # Split the text documents into nodes.
splitter = self._get_splitter( splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule", {}), processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"), embedding_model_instance=kwargs.get("embedding_model_instance"),
) )
all_documents = [] all_documents = []
@ -53,15 +66,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents) all_documents.extend(split_documents)
return all_documents return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
vector = Vector(dataset) vector = Vector(dataset)
vector.create(documents) vector.create(documents)
if with_keywords: if with_keywords:
keywords_list = kwargs.get("keywords_list")
keyword = Keyword(dataset) keyword = Keyword(dataset)
keyword.create(documents) if keywords_list and len(keywords_list) > 0:
keyword.add_texts(documents, keywords_list=keywords_list)
else:
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
vector = Vector(dataset) vector = Vector(dataset)
if node_ids: if node_ids:

View File

@ -0,0 +1,189 @@
"""Paragraph index processor."""
import uuid
from typing import Optional
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DocumentSegment
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
class ParentChildIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
rules = Rule(**process_rule.get("rules"))
all_documents = []
if rules.parent_mode == ParentMode.PARAGRAPH:
# Split the text documents into nodes.
splitter = self._get_splitter(
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, process_rule)
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
page_content = page_content[1:].strip()
else:
page_content = page_content
if len(page_content) > 0:
document_node.page_content = page_content
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document_node.children = child_nodes
split_documents.append(document_node)
all_documents.extend(split_documents)
elif rules.parent_mode == ParentMode.FULL_DOC:
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document.children = child_nodes
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata["doc_id"] = doc_id
document.metadata["doc_hash"] = hash
all_documents.append(document)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
child_documents = document.children
if child_documents:
formatted_child_documents = [
Document(**child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
vector = Vector(dataset)
if node_ids:
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids)
if delete_child_chunks:
db.session.query(ChildChunk).filter(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
db.session.commit()
else:
vector.delete()
if delete_child_chunks:
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
db.session.commit()
def retrieve(
self,
retrieval_method: str,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def _split_child_nodes(
self,
document_node: Document,
rules: Rule,
process_rule_mode: str,
embedding_model_instance: Optional[ModelInstance],
) -> list[ChildDocument]:
child_splitter = self._get_splitter(
processing_rule_mode=process_rule_mode,
max_tokens=rules.subchunk_segmentation.max_tokens,
chunk_overlap=rules.subchunk_segmentation.chunk_overlap,
separator=rules.subchunk_segmentation.separator,
embedding_model_instance=embedding_model_instance,
)
# parse document to child nodes
child_nodes = []
child_documents = child_splitter.split_documents([document_node])
for child_document_node in child_documents:
if child_document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(child_document_node.page_content)
child_document = ChildDocument(
page_content=child_document_node.page_content, metadata=document_node.metadata
)
child_document.metadata["doc_id"] = doc_id
child_document.metadata["doc_hash"] = hash
child_page_content = child_document.page_content
if child_page_content.startswith(".") or child_page_content.startswith(""):
child_page_content = child_page_content[1:].strip()
if len(child_page_content) > 0:
child_document.page_content = child_page_content
child_nodes.append(child_document)
return child_nodes
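For reference, a hedged sketch of the hierarchical process_rule shape that transform() above expects; every literal value here is illustrative rather than taken from the commit:

process_rule = {
    "mode": "hierarchical",
    "rules": {
        "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
        "segmentation": {"separator": "\n\n", "max_tokens": 500, "chunk_overlap": 50},
        "parent_mode": "paragraph",
        "subchunk_segmentation": {"separator": "\n", "max_tokens": 200, "chunk_overlap": 0},
    },
}
# documents = ParentChildIndexProcessor().transform(text_docs, process_rule=process_rule,
#                                                   embedding_model_instance=None)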

View File

@ -21,18 +21,28 @@ from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper from libs import helper
from models.dataset import Dataset from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import Rule
class QAIndexProcessor(BaseIndexProcessor): class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract( text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
) )
return text_docs return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
rules = Rule(**process_rule.get("rules"))
splitter = self._get_splitter( splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule") or {}, processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"), embedding_model_instance=kwargs.get("embedding_model_instance"),
) )
@ -59,24 +69,33 @@ class QAIndexProcessor(BaseIndexProcessor):
document_node.page_content = remove_leading_symbols(page_content) document_node.page_content = remove_leading_symbols(page_content)
split_documents.append(document_node) split_documents.append(document_node)
all_documents.extend(split_documents) all_documents.extend(split_documents)
for i in range(0, len(all_documents), 10): if preview:
threads = [] self._format_qa_document(
sub_documents = all_documents[i : i + 10] current_app._get_current_object(),
for doc in sub_documents: kwargs.get("tenant_id"),
document_format_thread = threading.Thread( all_documents[0],
target=self._format_qa_document, all_qa_documents,
kwargs={ kwargs.get("doc_language", "English"),
"flask_app": current_app._get_current_object(), # type: ignore )
"tenant_id": kwargs.get("tenant_id"), else:
"document_node": doc, for i in range(0, len(all_documents), 10):
"all_qa_documents": all_qa_documents, threads = []
"document_language": kwargs.get("doc_language", "English"), sub_documents = all_documents[i : i + 10]
}, for doc in sub_documents:
) document_format_thread = threading.Thread(
threads.append(document_format_thread) target=self._format_qa_document,
document_format_thread.start() kwargs={
for thread in threads: "flask_app": current_app._get_current_object(),
thread.join() "tenant_id": kwargs.get("tenant_id"),
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": kwargs.get("doc_language", "English"),
},
)
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
return all_qa_documents return all_qa_documents
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
@ -98,12 +117,12 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e)) raise ValueError(str(e))
return text_docs return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
vector = Vector(dataset) vector = Vector(dataset)
vector.create(documents) vector.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
vector = Vector(dataset) vector = Vector(dataset)
if node_ids: if node_ids:
vector.delete_by_ids(node_ids) vector.delete_by_ids(node_ids)

View File

@ -5,6 +5,19 @@ from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class ChildDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
vector: Optional[list[float]] = None
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: Optional[dict] = Field(default_factory=dict)
class Document(BaseModel): class Document(BaseModel):
"""Class for storing a piece of text and associated metadata.""" """Class for storing a piece of text and associated metadata."""
@ -19,6 +32,8 @@ class Document(BaseModel):
provider: Optional[str] = "dify" provider: Optional[str] = "dify"
children: Optional[list[ChildDocument]] = None
class BaseDocumentTransformer(ABC): class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems. """Abstract base class for document transformation systems.

View File

@ -166,43 +166,29 @@ class DatasetRetrieval:
"content": item.page_content, "content": item.page_content,
} }
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
document_score_list = {}
# deal with dify documents # deal with dify documents
if dify_documents: if dify_documents:
for item in dify_documents: records = RetrievalService.format_retrieval_documents(dify_documents)
if item.metadata.get("score"): if records:
document_score_list[item.metadata["doc_id"]] = item.metadata["score"] for record in records:
segment = record.segment
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer: if segment.answer:
document_context_list.append( document_context_list.append(
DocumentContext( DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}", content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=document_score_list.get(segment.index_node_id, None), score=record.score,
) )
) )
else: else:
document_context_list.append( document_context_list.append(
DocumentContext( DocumentContext(
content=segment.get_sign_content(), content=segment.get_sign_content(),
score=document_score_list.get(segment.index_node_id, None), score=record.score,
) )
) )
if show_retrieve_source: if show_retrieve_source:
for segment in sorted_segments: for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter( document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
@ -218,7 +204,7 @@ class DatasetRetrieval:
"data_source_type": document.data_source_type, "data_source_type": document.data_source_type,
"segment_id": segment.id, "segment_id": segment.id,
"retriever_from": invoke_from.to_source(), "retriever_from": invoke_from.to_source(),
"score": document_score_list.get(segment.index_node_id, 0.0), "score": record.score or 0.0,
} }
if invoke_from.to_source() == "dev": if invoke_from.to_source() == "dev":

View File

@ -11,6 +11,7 @@ from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment from core.variables import StringSegment
@ -18,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from .entities import KnowledgeRetrievalNodeData from .entities import KnowledgeRetrievalNodeData
@ -211,29 +212,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"content": item.page_content, "content": item.page_content,
} }
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
document_score_list: dict[str, float] = {}
# deal with dify documents # deal with dify documents
if dify_documents: if dify_documents:
document_score_list = {} records = RetrievalService.format_retrieval_documents(dify_documents)
for item in dify_documents: if records:
if item.metadata.get("score"): for record in records:
document_score_list[item.metadata["doc_id"]] = item.metadata["score"] segment = record.segment
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter( document = Document.query.filter(
Document.id == segment.document_id, Document.id == segment.document_id,
@ -251,7 +235,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"document_data_source_type": document.data_source_type, "document_data_source_type": document.data_source_type,
"segment_id": segment.id, "segment_id": segment.id,
"retriever_from": "workflow", "retriever_from": "workflow",
"score": document_score_list.get(segment.index_node_id, None), "score": record.score or 0.0,
"segment_hit_count": segment.hit_count, "segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count, "segment_word_count": segment.word_count,
"segment_position": segment.position, "segment_position": segment.position,
@ -270,10 +254,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
reverse=True, reverse=True,
) )
position = 1 for position, item in enumerate(retrieval_resource_list, start=1):
for item in retrieval_resource_list:
item["metadata"]["position"] = position item["metadata"]["position"] = position
position += 1
return retrieval_resource_list return retrieval_resource_list
@classmethod @classmethod

View File

@ -73,6 +73,7 @@ dataset_detail_fields = {
"embedding_available": fields.Boolean, "embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"tags": fields.List(fields.Nested(tag_fields)), "tags": fields.List(fields.Nested(tag_fields)),
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields), "external_knowledge_info": fields.Nested(external_knowledge_info_fields),
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
} }

View File

@ -34,6 +34,7 @@ document_with_segments_fields = {
"data_source_info": fields.Raw(attribute="data_source_info_dict"), "data_source_info": fields.Raw(attribute="data_source_info_dict"),
"data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
"dataset_process_rule_id": fields.String, "dataset_process_rule_id": fields.String,
"process_rule_dict": fields.Raw(attribute="process_rule_dict"),
"name": fields.String, "name": fields.String,
"created_from": fields.String, "created_from": fields.String,
"created_by": fields.String, "created_by": fields.String,

View File

@ -34,8 +34,16 @@ segment_fields = {
"document": fields.Nested(document_fields), "document": fields.Nested(document_fields),
} }
child_chunk_fields = {
"id": fields.String,
"content": fields.String,
"position": fields.Integer,
"score": fields.Float,
}
hit_testing_record_fields = { hit_testing_record_fields = {
"segment": fields.Nested(segment_fields), "segment": fields.Nested(segment_fields),
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"score": fields.Float, "score": fields.Float,
"tsne_position": fields.Raw, "tsne_position": fields.Raw,
} }

View File

@ -2,6 +2,17 @@ from flask_restful import fields # type: ignore
from libs.helper import TimestampField from libs.helper import TimestampField
child_chunk_fields = {
"id": fields.String,
"segment_id": fields.String,
"content": fields.String,
"position": fields.Integer,
"word_count": fields.Integer,
"type": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
}
segment_fields = { segment_fields = {
"id": fields.String, "id": fields.String,
"position": fields.Integer, "position": fields.Integer,
@ -20,10 +31,13 @@ segment_fields = {
"status": fields.String, "status": fields.String,
"created_by": fields.String, "created_by": fields.String,
"created_at": TimestampField, "created_at": TimestampField,
"updated_at": TimestampField,
"updated_by": fields.String,
"indexing_at": TimestampField, "indexing_at": TimestampField,
"completed_at": TimestampField, "completed_at": TimestampField,
"error": fields.String, "error": fields.String,
"stopped_at": TimestampField, "stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
} }
segment_list_response = { segment_list_response = {

View File

@ -0,0 +1,55 @@
"""parent-child-index
Revision ID: e19037032219
Revises: d7999dfa4aae
Create Date: 2024-11-22 07:01:17.550037
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'e19037032219'
down_revision = 'd7999dfa4aae'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('child_chunks',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('segment_id', models.types.StringUUID(), nullable=False),
sa.Column('position', sa.Integer(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('word_count', sa.Integer(), nullable=False),
sa.Column('index_node_id', sa.String(length=255), nullable=True),
sa.Column('index_node_hash', sa.String(length=255), nullable=True),
sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('indexing_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.Column('error', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id', name='child_chunk_pkey')
)
with op.batch_alter_table('child_chunks', schema=None) as batch_op:
batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('child_chunks', schema=None) as batch_op:
batch_op.drop_index('child_chunk_dataset_id_idx')
op.drop_table('child_chunks')
# ### end Alembic commands ###

View File

@ -0,0 +1,47 @@
"""add_auto_disabled_dataset_logs
Revision ID: 923752d42eb6
Revises: e19037032219
Create Date: 2024-12-25 11:37:55.467101
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '923752d42eb6'
down_revision = 'e19037032219'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_auto_disable_logs',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey')
)
with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False)
batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False)
batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
batch_op.drop_index('dataset_auto_disable_log_tenant_idx')
batch_op.drop_index('dataset_auto_disable_log_dataset_idx')
batch_op.drop_index('dataset_auto_disable_log_created_atx')
op.drop_table('dataset_auto_disable_logs')
# ### end Alembic commands ###

View File

@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import JSONB
from configs import dify_config from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account from .account import Account
from .engine import db from .engine import db
@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
MODES = ["automatic", "custom"] MODES = ["automatic", "custom", "hierarchical"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
AUTOMATIC_RULES: dict[str, Any] = { AUTOMATIC_RULES: dict[str, Any] = {
"pre_processing_rules": [ "pre_processing_rules": [
@ -231,8 +232,6 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
"dataset_id": self.dataset_id, "dataset_id": self.dataset_id,
"mode": self.mode, "mode": self.mode,
"rules": self.rules_dict, "rules": self.rules_dict,
"created_by": self.created_by,
"created_at": self.created_at,
} }
@property @property
@ -396,6 +395,12 @@ class Document(db.Model): # type: ignore[name-defined]
.scalar() .scalar()
) )
@property
def process_rule_dict(self):
if self.dataset_process_rule_id:
return self.dataset_process_rule.to_dict()
return None
def to_dict(self): def to_dict(self):
return { return {
"id": self.id, "id": self.id,
@ -560,6 +565,24 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
.first() .first()
) )
@property
def child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
return []
def get_sign_content(self): def get_sign_content(self):
signed_urls = [] signed_urls = []
text = self.content text = self.content
@ -605,6 +628,47 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
return text return text
class ChildChunk(db.Model):
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
)
# initial fields
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
segment_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False)
content = db.Column(db.Text, nullable=False)
word_count = db.Column(db.Integer, nullable=False)
# indexing fields
index_node_id = db.Column(db.String(255), nullable=True)
index_node_hash = db.Column(db.String(255), nullable=True)
type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
indexing_at = db.Column(db.DateTime, nullable=True)
completed_at = db.Column(db.DateTime, nullable=True)
error = db.Column(db.Text, nullable=True)
@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()
@property
def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first()
@property
def segment(self):
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(db.Model): # type: ignore[name-defined] class AppDatasetJoin(db.Model): # type: ignore[name-defined]
__tablename__ = "app_dataset_joins" __tablename__ = "app_dataset_joins"
__table_args__ = ( __table_args__ = (
@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True) updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(db.Model):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
db.Index("dataset_auto_disable_log_created_atx", "created_at"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

View File

@ -10,7 +10,7 @@ from configs import dify_config
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetQuery, Document from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -75,6 +75,23 @@ def clean_unused_datasets_task():
) )
if not dataset_query or len(dataset_query) == 0: if not dataset_query or len(dataset_query) == 0:
try: try:
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index # remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None) index_processor.clean(dataset, None)
@ -151,6 +168,23 @@ def clean_unused_datasets_task():
else: else:
plan = plan_cache.decode() plan = plan_cache.decode()
if plan == "sandbox": if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index # remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None) index_processor.clean(dataset, None)

View File

@ -0,0 +1,66 @@
import logging
import time
import click
from celery import shared_task
from flask import render_template
from extensions.ext_mail import mail
from models.account import Account, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetAutoDisableLog
@shared_task(queue="mail")
def send_document_clean_notify_task():
"""
Async Send document clean notify mail
Usage: send_document_clean_notify_task.delay()
"""
if not mail.is_inited():
return
logging.info(click.style("Start send document clean notify mail", fg="green"))
start_at = time.perf_counter()
# send document clean notify mail
try:
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
# group by tenant_id
dataset_auto_disable_logs_map = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
knowledge_details = []
tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
if not tenant:
continue
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
if not current_owner_join:
continue
account = Account.query.filter(Account.id == current_owner_join.account_id).first()
if not account:
continue
# group this tenant's logs by dataset to report per-knowledge-base counts
dataset_auto_dataset_map = defaultdict(list)
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
if dataset:
document_count = len(document_ids)
knowledge_details.append(f"Knowledge base {dataset.name}: {document_count} documents")
if not knowledge_details:
continue
html_content = render_template(
"clean_document_job_mail_template-US.html",
userName=account.name,
knowledge_details=knowledge_details,
url=f"{dify_config.CONSOLE_WEB_URL}/datasets",  # assumed console knowledge-base list URL
)
mail.send(
to=account.email,
subject="Some documents in your knowledge base have been disabled",
html=html_content,
)
# mark the logs as notified so the mail is not sent again on the next run
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_disable_log.notified = True
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
)
except Exception:
logging.exception("Send document clean notify mail failed")

File diff suppressed because it is too large

View File

@ -1,4 +1,5 @@
from typing import Optional from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel):
answer: Optional[str] = None answer: Optional[str] = None
keywords: Optional[list[str]] = None keywords: Optional[list[str]] = None
enabled: Optional[bool] = None enabled: Optional[bool] = None
class ParentMode(str, Enum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
class NotionIcon(BaseModel):
type: str
url: Optional[str] = None
emoji: Optional[str] = None
class NotionPage(BaseModel):
page_id: str
page_name: str
page_icon: Optional[NotionIcon] = None
type: str
class NotionInfo(BaseModel):
workspace_id: str
pages: list[NotionPage]
class WebsiteInfo(BaseModel):
provider: str
job_id: str
urls: list[str]
only_main_content: bool = True
class FileInfo(BaseModel):
file_ids: list[str]
class InfoList(BaseModel):
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
notion_info_list: Optional[list[NotionInfo]] = None
file_info_list: Optional[FileInfo] = None
website_info_list: Optional[WebsiteInfo] = None
class DataSource(BaseModel):
info_list: InfoList
class PreProcessingRule(BaseModel):
id: str
enabled: bool
class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0
class Rule(BaseModel):
pre_processing_rules: Optional[list[PreProcessingRule]] = None
segmentation: Optional[Segmentation] = None
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
subchunk_segmentation: Optional[Segmentation] = None
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Optional[Rule] = None
class RerankingModel(BaseModel):
reranking_provider_name: Optional[str] = None
reranking_model_name: Optional[str] = None
class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search"]
reranking_enable: bool
reranking_model: Optional[RerankingModel] = None
top_k: int
score_threshold_enabled: bool
score_threshold: Optional[float] = None
class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: Optional[DataSource] = None
process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: Optional[str] = None
embedding_model_provider: Optional[str] = None
name: Optional[str] = None
class SegmentUpdateArgs(BaseModel):
content: Optional[str] = None
answer: Optional[str] = None
keywords: Optional[list[str]] = None
regenerate_child_chunks: bool = False
enabled: Optional[bool] = None
class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
content: str
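
For reference, a minimal sketch of how these entities compose into a parent-child (hierarchical) knowledge configuration. The import path matches how KnowledgeConfig is imported elsewhere in this commit; the doc_form value (assumed to match IndexType.PARENT_CHILD_INDEX), the file id, and the separator/token settings are illustrative assumptions, not defaults taken from the codebase:

from services.entities.knowledge_entities.knowledge_entities import (
    DataSource,
    FileInfo,
    InfoList,
    KnowledgeConfig,
    ProcessRule,
    Rule,
    Segmentation,
)

config = KnowledgeConfig(
    indexing_technique="high_quality",
    doc_form="hierarchical_model",  # assumed string value of IndexType.PARENT_CHILD_INDEX
    data_source=DataSource(
        info_list=InfoList(
            data_source_type="upload_file",
            file_info_list=FileInfo(file_ids=["<upload-file-id>"]),  # placeholder id
        )
    ),
    process_rule=ProcessRule(
        mode="hierarchical",
        rules=Rule(
            parent_mode="paragraph",  # "full-doc" would use the whole document as the parent chunk
            segmentation=Segmentation(separator="\n\n", max_tokens=1024),
            subchunk_segmentation=Segmentation(separator="\n", max_tokens=256),
        ),
    ),
)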

View File

@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class ChildChunkIndexingError(BaseServiceError):
description = "{message}"
class ChildChunkDeleteIndexError(BaseServiceError):
description = "{message}"

View File

@ -7,7 +7,7 @@ from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Dataset, DatasetQuery
default_retrieval_model = { default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@ -69,7 +69,7 @@ class HitTestingService:
db.session.add(dataset_query) db.session.add(dataset_query)
db.session.commit() db.session.commit()
return dict(cls.compact_retrieve_response(dataset, query, all_documents)) return cls.compact_retrieve_response(query, all_documents)
@classmethod @classmethod
def external_retrieve( def external_retrieve(
@ -106,41 +106,14 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod @classmethod
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): def compact_retrieve_response(cls, query: str, documents: list[Document]):
records = [] records = RetrievalService.format_retrieval_documents(documents)
for document in documents:
if document.metadata is None:
continue
index_node_id = document.metadata["doc_id"]
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
if not segment:
continue
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
records.append(record)
return { return {
"query": { "query": {
"content": query, "content": query,
}, },
"records": records, "records": [record.model_dump() for record in records],
} }
@classmethod @classmethod

View File

@ -1,40 +1,68 @@
from typing import Optional from typing import Optional
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document from core.rag.models.document import Document
from models.dataset import Dataset, DocumentSegment from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
class VectorService: class VectorService:
@classmethod @classmethod
def create_segments_vector( def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
): ):
documents = [] documents = []
for segment in segments: for segment in segments:
document = Document( if doc_form == IndexType.PARENT_CHILD_INDEX:
page_content=segment.content, document = DatasetDocument.query.filter_by(id=segment.document_id).first()
metadata={ # get the process rule
"doc_id": segment.index_node_id, processing_rule = (
"doc_hash": segment.index_node_hash, db.session.query(DatasetProcessRule)
"document_id": segment.document_id, .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
"dataset_id": segment.dataset_id, .first()
}, )
) # get embedding model instance
documents.append(document) if dataset.indexing_technique == "high_quality":
if dataset.indexing_technique == "high_quality": # check embedding model setting
# save vector index model_manager = ModelManager()
vector = Vector(dataset=dataset)
vector.add_texts(documents, duplicate_check=True)
# save keyword index if dataset.embedding_model_provider:
keyword = Keyword(dataset) embedding_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
if keywords_list and len(keywords_list) > 0: provider=dataset.embedding_model_provider,
keyword.add_texts(documents, keywords_list=keywords_list) model_type=ModelType.TEXT_EMBEDDING,
else: model=dataset.embedding_model,
keyword.add_texts(documents) )
else:
embedding_model_instance = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
raise ValueError("The knowledge base index technique is not high quality!")
cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
else:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
@classmethod @classmethod
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
@ -65,3 +93,123 @@ class VectorService:
keyword.add_texts([document], keywords_list=[keywords]) keyword.add_texts([document], keywords_list=[keywords])
else: else:
keyword.add_texts([document]) keyword.add_texts([document])
@classmethod
def generate_child_chunks(
cls,
segment: DocumentSegment,
dataset_document: Document,
dataset: Dataset,
embedding_model_instance: ModelInstance,
processing_rule: DatasetProcessRule,
regenerate: bool = False,
):
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
if regenerate:
# delete child chunks
index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)
# generate child chunks
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
# use full doc mode to generate segment's child chunk
processing_rule_dict = processing_rule.to_dict()
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
documents = index_processor.transform(
[document],
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule_dict,
tenant_id=dataset.tenant_id,
doc_language=dataset_document.doc_language,
)
# save child chunks
if len(documents) > 0 and len(documents[0].children) > 0:
index_processor.load(dataset, documents)
for position, child_chunk in enumerate(documents[0].children, start=1):
child_segment = ChildChunk(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=dataset_document.id,
segment_id=segment.id,
position=position,
index_node_id=child_chunk.metadata["doc_id"],
index_node_hash=child_chunk.metadata["doc_hash"],
content=child_chunk.page_content,
word_count=len(child_chunk.page_content),
type="automatic",
created_by=dataset_document.created_by,
)
db.session.add(child_segment)
db.session.commit()
@classmethod
def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
child_document = Document(
page_content=child_segment.content,
metadata={
"doc_id": child_segment.index_node_id,
"doc_hash": child_segment.index_node_hash,
"document_id": child_segment.document_id,
"dataset_id": child_segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)
@classmethod
def update_child_chunk_vector(
cls,
new_child_chunks: list[ChildChunk],
update_child_chunks: list[ChildChunk],
delete_child_chunks: list[ChildChunk],
dataset: Dataset,
):
documents = []
delete_node_ids = []
for new_child_chunk in new_child_chunks:
new_child_document = Document(
page_content=new_child_chunk.content,
metadata={
"doc_id": new_child_chunk.index_node_id,
"doc_hash": new_child_chunk.index_node_hash,
"document_id": new_child_chunk.document_id,
"dataset_id": new_child_chunk.dataset_id,
},
)
documents.append(new_child_document)
for update_child_chunk in update_child_chunks:
child_document = Document(
page_content=update_child_chunk.content,
metadata={
"doc_id": update_child_chunk.index_node_id,
"doc_hash": update_child_chunk.index_node_hash,
"document_id": update_child_chunk.document_id,
"dataset_id": update_child_chunk.dataset_id,
},
)
documents.append(child_document)
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == "high_quality":
# update vector index
vector = Vector(dataset=dataset)
if delete_node_ids:
vector.delete_by_ids(delete_node_ids)
if documents:
vector.add_texts(documents, duplicate_check=True)
@classmethod
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
vector.delete_by_ids([child_chunk.index_node_id])
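
A minimal usage sketch for the new doc_form-aware create_segments_vector entry point, mirroring how batch_create_segment_to_index_task calls it later in this commit; the ids are placeholders and the rows are assumed to already exist:

from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.vector_service import VectorService

dataset = db.session.query(Dataset).filter(Dataset.id == "<dataset-id>").first()
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == "<document-id>").first()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == dataset_document.id).all()

# Passing the document's doc_form routes parent-child documents through generate_child_chunks;
# keywords_list is optional and may be None.
VectorService.create_segments_vector(None, segments, dataset, dataset_document.doc_form)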

View File

@ -6,12 +6,13 @@ import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.dataset import DocumentSegment
@shared_task(queue="dataset") @shared_task(queue="dataset")
@ -53,7 +54,22 @@ def add_document_to_index_task(dataset_document_id: str):
"dataset_id": segment.dataset_id, "dataset_id": segment.dataset_id,
}, },
) )
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document) documents.append(document)
dataset = dataset_document.dataset dataset = dataset_document.dataset
@ -65,6 +81,12 @@ def add_document_to_index_task(dataset_document_id: str):
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents) index_processor.load(dataset, documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).filter(
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
db.session.commit()
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
click.style( click.style(

View File

@ -0,0 +1,75 @@
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DocumentSegment
from models.model import UploadFile
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
"""
Clean documents when documents are deleted.
:param document_ids: document ids
:param dataset_id: dataset id
:param doc_form: doc_form
:param file_ids: file ids
Usage: batch_clean_document_task.delay(document_ids, dataset_id, doc_form, file_ids)
"""
logging.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image file failed when storage deleted, image_upload_file_id: {}".format(upload_file_id)
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.commit()
if file_ids:
files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id))
db.session.delete(file)
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style(
"Cleaned documents when documents deleted latency: {}".format(end_at - start_at),
fg="green",
)
)
except Exception:
logging.exception("Cleaned documents when documents deleted failed")

View File

@ -7,13 +7,13 @@ import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from sqlalchemy import func from sqlalchemy import func
from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs import helper from libs import helper
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.vector_service import VectorService
@shared_task(queue="dataset") @shared_task(queue="dataset")
@ -96,8 +96,7 @@ def batch_create_segment_to_index_task(
dataset_document.word_count += word_count_change dataset_document.word_count += word_count_change
db.session.add(dataset_document) db.session.add(dataset_document)
# add index to db # add index to db
indexing_runner = IndexingRunner() VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
indexing_runner.batch_add_segments(document_segments, dataset)
db.session.commit() db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed") redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter() end_at = time.perf_counter()

View File

@ -62,7 +62,7 @@ def clean_dataset_task(
if doc_form is None: if doc_form is None:
raise ValueError("Index type must be specified.") raise ValueError("Index type must be specified.")
index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None) index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
for document in documents: for document in documents:
db.session.delete(document) db.session.delete(document)

View File

@ -38,7 +38,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if segments: if segments:
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content) image_upload_file_ids = get_image_upload_file_ids(segment.content)

View File

@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)

View File

@ -4,8 +4,9 @@ import time
import click import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
db.session.commit() db.session.commit()
# clean index # clean index
index_processor.clean(dataset, None, with_keywords=False) index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
# update from vector index # update from vector index
@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
"dataset_id": segment.dataset_id, "dataset_id": segment.dataset_id,
}, },
) )
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document) documents.append(document)
# save vector index # save vector index
index_processor.load(dataset, documents, with_keywords=False) index_processor.load(dataset, documents, with_keywords=False)

View File

@ -6,48 +6,38 @@ from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document from models.dataset import Dataset, Document
@shared_task(queue="dataset") @shared_task(queue="dataset")
def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
""" """
Async Remove segment from index Async Remove segment from index
:param segment_id: :param index_node_ids:
:param index_node_id:
:param dataset_id: :param dataset_id:
:param document_id: :param document_id:
Usage: delete_segment_from_index_task.delay(segment_id) Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id)
""" """
logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) logging.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter() start_at = time.perf_counter()
indexing_cache_key = "segment_{}_delete_indexing".format(segment_id)
try: try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset: if not dataset:
logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan"))
return return
dataset_document = db.session.query(Document).filter(Document.id == document_id).first() dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
if not dataset_document: if not dataset_document:
logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan"))
return return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan"))
return return
index_type = dataset_document.doc_form index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [index_node_id]) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green"))
click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green")
)
except Exception: except Exception:
logging.exception("delete segment from index failed") logging.exception("delete segment from index failed")
finally:
redis_client.delete(indexing_cache_key)
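
A short sketch of the new call shape, since the task now receives a list of index node ids rather than a single segment id; the module path follows the repo's tasks/ layout and the ids are placeholders:

from tasks.delete_segment_from_index_task import delete_segment_from_index_task

# index node ids collected from the segments being removed (placeholder values)
index_node_ids = ["node-id-1", "node-id-2"]
delete_segment_from_index_task.delay(index_node_ids, "<dataset-id>", "<document-id>")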

View File

@ -0,0 +1,76 @@
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str):
"""
Async disable segments from index
:param segment_ids: segment ids
:param dataset_id: dataset id
:param document_id: document id
Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
if not segments:
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green"))
except Exception:
# roll the segments back to enabled if cleaning the index failed
db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)

View File

@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)

View File

@ -47,7 +47,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)

View File

@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
if document: if document:
document.indexing_status = "error" document.indexing_status = "error"
document.error = str(e) document.error = str(e)
document.stopped_at = datetime.datetime.utcnow() document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
return return
@ -73,14 +73,14 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
db.session.commit() db.session.commit()
document.indexing_status = "parsing" document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.utcnow() document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
documents.append(document) documents.append(document)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()

View File

@ -6,8 +6,9 @@ import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
@ -61,6 +62,22 @@ def enable_segment_to_index_task(segment_id: str):
return return
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
# save vector index # save vector index
index_processor.load(dataset, [document]) index_processor.load(dataset, [document])

View File

@ -0,0 +1,108 @@
import datetime
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str):
"""
Async enable segments to index
:param segment_ids: segment ids
:param dataset_id: dataset id
:param document_id: document id
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
if not segments:
return
try:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents)
end_at = time.perf_counter()
logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green"))
except Exception as e:
logging.exception("enable segments to index failed")
# update segment error msg
db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"error": str(e),
"status": "error",
"disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
"enabled": False,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)

View File

@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str):
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids: if index_node_ids:
try: try:
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
except Exception: except Exception:
logging.exception(f"clean dataset {dataset.id} from index failed") logging.exception(f"clean dataset {dataset.id} from index failed")

View File

@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
if document: if document:
document.indexing_status = "error" document.indexing_status = "error"
document.error = str(e) document.error = str(e)
document.stopped_at = datetime.datetime.utcnow() document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
redis_client.delete(retry_indexing_cache_key) redis_client.delete(retry_indexing_cache_key)
@ -69,14 +69,14 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
if segments: if segments:
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
db.session.commit() db.session.commit()
document.indexing_status = "parsing" document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.utcnow() document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
except Exception as ex: except Exception as ex:
document.indexing_status = "error" document.indexing_status = "error"
document.error = str(ex) document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow() document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
logging.info(click.style(str(ex), fg="yellow")) logging.info(click.style(str(ex), fg="yellow"))

View File

@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
if document: if document:
document.indexing_status = "error" document.indexing_status = "error"
document.error = str(e) document.error = str(e)
document.stopped_at = datetime.datetime.utcnow() document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
redis_client.delete(sync_indexing_cache_key) redis_client.delete(sync_indexing_cache_key)
@ -65,14 +65,14 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
if segments: if segments:
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
db.session.commit() db.session.commit()
document.indexing_status = "parsing" document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.utcnow() document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
except Exception as ex: except Exception as ex:
document.indexing_status = "error" document.indexing_status = "error"
document.error = str(ex) document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow() document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
logging.info(click.style(str(ex), fg="yellow")) logging.info(click.style(str(ex), fg="yellow"))

View File

@ -0,0 +1,98 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Documents Disabled Notification</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 0;
background-color: #f5f5f5;
}
.email-container {
max-width: 600px;
margin: 20px auto;
background: #ffffff;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
overflow: hidden;
}
.header {
background-color: #eef2fa;
padding: 20px;
text-align: center;
}
.header img {
height: 40px;
}
.content {
padding: 20px;
line-height: 1.6;
color: #333;
}
.content h1 {
font-size: 24px;
color: #222;
}
.content p {
margin: 10px 0;
}
.content ul {
padding-left: 20px;
}
.content ul li {
margin-bottom: 10px;
}
.cta-button {
display: block;
margin: 20px auto;
padding: 10px 20px;
background-color: #4e89f9;
color: #ffffff;
text-align: center;
text-decoration: none;
border-radius: 5px;
width: fit-content;
}
.footer {
text-align: center;
padding: 10px;
font-size: 12px;
color: #777;
background-color: #f9f9f9;
}
</style>
</head>
<body>
<div class="email-container">
<!-- Header -->
<div class="header">
<img src="https://via.placeholder.com/150x40?text=Dify" alt="Dify Logo">
</div>
<!-- Content -->
<div class="content">
<h1>Some Documents in Your Knowledge Base Have Been Disabled</h1>
<p>Dear {{userName}},</p>
<p>
We're sorry for the inconvenience. To ensure optimal performance, documents
that haven't been updated or accessed in the past 7 days have been disabled in
your knowledge bases:
</p>
<ul>
{% for item in knowledge_details %}
<li>{{ item }}</li>
{% endfor %}
</ul>
<p>You can re-enable them anytime.</p>
<a href="{{url}}" class="cta-button">Re-enable in Dify</a>
</div>
<!-- Footer -->
<div class="footer">
Sincerely,<br>
The Dify Team
</div>
</div>
</body>
</html>
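
For reference, a sketch of the context this template expects when rendered from the notify task; the values below are illustrative and render_template assumes a Flask application context:

from flask import render_template

html_content = render_template(
    "clean_document_job_mail_template-US.html",
    userName="Jane Doe",  # workspace owner's display name
    knowledge_details=["Knowledge base Product Docs: 3 documents"],  # one entry per affected knowledge base
    url="https://console.example.com/datasets",  # assumed console knowledge-base list URL
)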