Feat/support parent child chunk (#12092)

Jyong 2024-12-25 19:49:07 +08:00 committed by GitHub
parent 017d7538ae
commit 9231fdbf4c
54 changed files with 2578 additions and 808 deletions

View File

@ -218,7 +218,7 @@ class DataSourceNotionApi(Resource):
args["doc_form"],
args["doc_language"],
)
return response, 200
return response.model_dump(), 200
class DataSourceNotionDatasetSyncApi(Resource):

View File

@ -464,7 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
except Exception as e:
raise IndexingEstimateError(str(e))
return response, 200
return response.model_dump(), 200
class DatasetRelatedAppListApi(Resource):
@ -733,6 +733,18 @@ class DatasetPermissionUserListApi(Resource):
}, 200
class DatasetAutoDisableLogApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
@ -747,3 +759,4 @@ api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")
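
A quick client sketch for the new auto-disable-logs route registered above (the base URL, dataset id, and session-cookie auth are placeholders, not part of this diff):

import requests

BASE = "http://localhost:5001/console/api"  # assumed console API mount point
dataset_id = "your-dataset-uuid"
auth_cookies = {"session": "your-console-session"}  # assumed auth mechanism

# Fetch the auto-disable log entries for a dataset (new DatasetAutoDisableLogApi).
resp = requests.get(f"{BASE}/datasets/{dataset_id}/auto-disable-logs", cookies=auth_cookies)
print(resp.status_code, resp.json())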

View File

@ -52,6 +52,7 @@ from fields.document_fields import (
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
@ -255,20 +256,22 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
if not dataset.indexing_technique and not args["indexing_technique"]:
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")
# validate args
DocumentService.document_create_args_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@ -278,6 +281,25 @@ class DatasetDocumentListApi(Resource):
return {"documents": documents, "batch": batch}
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
try:
document_ids = request.args.getlist("document_id")
DocumentService.delete_documents(dataset, document_ids)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
class DatasetInitApi(Resource):
@setup_required
@ -313,9 +335,9 @@ class DatasetInitApi(Resource):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
if args["indexing_technique"] == "high_quality":
if args["embedding_model"] is None or args["embedding_model_provider"] is None:
knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
@ -334,11 +356,11 @@ class DatasetInitApi(Resource):
raise ProviderNotInitializeError(ex.description)
# validate args
DocumentService.document_create_args_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, document_data=args, account=current_user
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -409,7 +431,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
except Exception as e:
raise IndexingEstimateError(str(e))
return response
return response.model_dump(), 200
class DocumentBatchIndexingEstimateApi(DocumentResource):
@ -422,7 +444,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
documents = self.get_batch_documents(dataset_id, batch)
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
if not documents:
return response
return response, 200
data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
info_list = []
@ -509,7 +531,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise ProviderNotInitializeError(ex.description)
except Exception as e:
raise IndexingEstimateError(str(e))
return response
return response.model_dump(), 200
class DocumentBatchIndexingStatusApi(DocumentResource):
@ -582,7 +604,8 @@ class DocumentDetailApi(DocumentResource):
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
elif metadata == "without":
process_rules = DatasetService.get_process_rules(dataset_id)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@ -590,7 +613,8 @@ class DocumentDetailApi(DocumentResource):
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": process_rules,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
@ -613,7 +637,8 @@ class DocumentDetailApi(DocumentResource):
"doc_language": document.doc_language,
}
else:
process_rules = DatasetService.get_process_rules(dataset_id)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@ -621,7 +646,8 @@ class DocumentDetailApi(DocumentResource):
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": process_rules,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
@ -757,9 +783,8 @@ class DocumentStatusApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, action):
def patch(self, dataset_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@ -774,17 +799,18 @@ class DocumentStatusApi(DocumentResource):
# check user's permission
DatasetService.check_dataset_permission(dataset, current_user)
document_ids = request.args.getlist("document_id")
for document_id in document_ids:
document = self.get_document(dataset_id, document_id)
indexing_cache_key = "document_{}_indexing".format(document.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")
raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")
if action == "enable":
if document.enabled:
raise InvalidActionError("Document already enabled.")
continue
document.enabled = True
document.disabled_at = None
document.disabled_by = None
@ -796,13 +822,11 @@ class DocumentStatusApi(DocumentResource):
add_document_to_index_task.delay(document_id)
return {"result": "success"}, 200
elif action == "disable":
if not document.completed_at or document.indexing_status != "completed":
raise InvalidActionError("Document is not completed.")
raise InvalidActionError(f"Document: {document.name} is not completed.")
if not document.enabled:
raise InvalidActionError("Document already disabled.")
continue
document.enabled = False
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
@ -815,11 +839,9 @@ class DocumentStatusApi(DocumentResource):
remove_document_from_index_task.delay(document_id)
return {"result": "success"}, 200
elif action == "archive":
if document.archived:
raise InvalidActionError("Document already archived.")
continue
document.archived = True
document.archived_at = datetime.now(UTC).replace(tzinfo=None)
@ -833,11 +855,9 @@ class DocumentStatusApi(DocumentResource):
remove_document_from_index_task.delay(document_id)
return {"result": "success"}, 200
elif action == "un_archive":
if not document.archived:
raise InvalidActionError("Document is not archived.")
continue
document.archived = False
document.archived_at = None
document.archived_by = None
@ -849,9 +869,9 @@ class DocumentStatusApi(DocumentResource):
add_document_to_index_task.delay(document_id)
return {"result": "success"}, 200
else:
raise InvalidActionError()
return {"result": "success"}, 200
class DocumentPauseApi(DocumentResource):
@ -1022,7 +1042,7 @@ api.add_resource(
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
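
The status and delete endpoints above now act on batches of documents identified by repeated document_id query parameters. A hedged client sketch (base URL and auth are placeholders; the documents-collection path used for the batch delete is assumed from DatasetDocumentListApi, whose registration is outside this excerpt):

import requests

BASE = "http://localhost:5001/console/api"  # assumed console API mount point
dataset_id = "your-dataset-uuid"
auth_cookies = {"session": "your-console-session"}  # assumed auth mechanism
params = [("document_id", "doc-uuid-1"), ("document_id", "doc-uuid-2")]

# Enable, disable, archive, or un_archive a batch of documents in one call.
requests.patch(
    f"{BASE}/datasets/{dataset_id}/documents/status/enable/batch",
    params=params,
    cookies=auth_cookies,
)

# Batch-delete documents; the handler returns 204 on success.
requests.delete(f"{BASE}/datasets/{dataset_id}/documents", params=params, cookies=auth_cookies)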

View File

@ -1,5 +1,4 @@
import uuid
from datetime import UTC, datetime
import pandas as pd
from flask import request
@ -10,7 +9,13 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.datasets.error import (
ChildChunkDeleteIndexError,
ChildChunkIndexingError,
InvalidActionError,
NoFileUploadedError,
TooManyFilesError,
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_knowledge_limit_check,
@ -20,15 +25,15 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
from models import DocumentSegment
from models.dataset import ChildChunk, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
class DatasetDocumentSegmentListApi(Resource):
@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource):
raise NotFound("Document not found.")
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=str, default=None, location="args")
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
args = parser.parse_args()
last_id = args["last_id"]
page = args["page"]
limit = min(args["limit"], 100)
status_list = args["status"]
hit_count_gte = args["hit_count_gte"]
@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
)
if last_id is not None:
last_segment = db.session.get(DocumentSegment, str(last_id))
if last_segment:
query = query.filter(DocumentSegment.position > last_segment.position)
else:
return {"data": [], "has_more": False, "limit": limit}, 200
).order_by(DocumentSegment.position.asc())
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource):
elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False)
total = query.count()
segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
has_more = False
if len(segments) > limit:
has_more = True
segments = segments[:-1]
return {
"data": marshal(segments, segment_fields),
"doc_form": document.doc_form,
"has_more": has_more,
response = {
"data": marshal(segments.items, segment_fields),
"limit": limit,
"total": total,
}, 200
"total": segments.total,
"total_pages": segments.pages,
"page": page,
}
return response, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_ids = request.args.getlist("segment_id")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
return {"result": "success"}, 200
class DatasetDocumentSegmentApi(Resource):
@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, segment_id, action):
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor
@ -147,59 +173,17 @@ class DatasetDocumentSegmentApi(Resource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment_ids = request.args.getlist("segment_id")
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
if segment.status != "completed":
raise NotFound("Segment is not completed, enable or disable function is not allowed")
document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
document_indexing_cache_key = "document_{}_indexing".format(document.id)
cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")
indexing_cache_key = "segment_{}_indexing".format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Segment is being indexed, please try again later")
if action == "enable":
if segment.enabled:
raise InvalidActionError("Segment is already enabled.")
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.commit()
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
enable_segment_to_index_task.delay(segment.id)
try:
SegmentService.update_segments_status(segment_ids, action, dataset, document)
except Exception as e:
raise InvalidActionError(str(e))
return {"result": "success"}, 200
elif action == "disable":
if not segment.enabled:
raise InvalidActionError("Segment is already disabled.")
segment.enabled = False
segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
segment.disabled_by = current_user.id
db.session.commit()
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
disable_segment_from_index_task.delay(segment.id)
return {"result": "success"}, 200
else:
raise InvalidActionError()
class DatasetDocumentSegmentAddApi(Resource):
@ -307,9 +291,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
parser.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@ -412,8 +399,248 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
class ChildChunkAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
try:
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.update_child_chunk(
args.get("content"), child_chunk, segment, document, dataset
)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
api.add_resource(
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
DatasetDocumentSegmentUpdateApi,
@ -424,3 +651,11 @@ api.add_resource(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>",
)
api.add_resource(
ChildChunkAddApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
)
api.add_resource(
ChildChunkUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)
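
A hedged sketch of the new child-chunk endpoints registered above (base URL and auth are placeholders; the fields accepted inside "chunks" mirror ChildChunkUpdateArgs, which is not shown in this excerpt, so the "id" key is an assumption):

import requests

BASE = "http://localhost:5001/console/api"  # assumed console API mount point
auth_cookies = {"session": "your-console-session"}  # assumed auth mechanism
seg_url = (
    f"{BASE}/datasets/your-dataset-uuid/documents/your-document-uuid"
    "/segments/your-segment-uuid/child_chunks"
)

# Create, list, and bulk-update the child chunks of a segment.
requests.post(seg_url, json={"content": "a new child chunk"}, cookies=auth_cookies)
requests.get(seg_url, params={"page": 1, "limit": 20, "keyword": "chunk"}, cookies=auth_cookies)
requests.patch(
    seg_url,
    json={"chunks": [{"id": "child-chunk-uuid", "content": "edited content"}]},  # "id" key assumed
    cookies=auth_cookies,
)

# Edit or delete a single child chunk.
requests.patch(f"{seg_url}/child-chunk-uuid", json={"content": "edited content"}, cookies=auth_cookies)
requests.delete(f"{seg_url}/child-chunk-uuid", cookies=auth_cookies)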

View File

@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException):
error_code = "indexing_estimate_error"
description = "Knowledge indexing estimate failed: {message}"
code = 500
class ChildChunkIndexingError(BaseHTTPException):
error_code = "child_chunk_indexing_error"
description = "Create child chunk index failed: {message}"
code = 500
class ChildChunkDeleteIndexError(BaseHTTPException):
error_code = "child_chunk_delete_index_error"
description = "Delete child chunk index failed: {message}"
code = 500

View File

@ -16,6 +16,7 @@ from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
class SegmentApi(DatasetApiResource):
@ -193,7 +194,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = parser.parse_args()
SegmentService.segment_create_args_validate(args["segment"], document)
segment = SegmentService.update_segment(args["segment"], segment, document, dataset)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

View File

@ -0,0 +1,19 @@
from typing import Optional
from pydantic import BaseModel
class PreviewDetail(BaseModel):
content: str
child_chunks: Optional[list[str]] = None
class QAPreviewDetail(BaseModel):
question: str
answer: str
class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None
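
These models back the estimate payloads that the controllers above now serialize with .model_dump(); a minimal construction sketch:

from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail

estimate = IndexingEstimate(
    total_segments=2,
    preview=[
        PreviewDetail(content="Parent chunk text", child_chunks=["child chunk one", "child chunk two"]),
    ],
    qa_preview=[QAPreviewDetail(question="What does the parent chunk cover?", answer="A short generated answer.")],
)
payload = estimate.model_dump()  # plain dict, as returned by the estimate endpoints with a 200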

View File

@ -8,34 +8,34 @@ import time
import uuid
from typing import Any, Optional, cast
from flask import Flask, current_app
from flask import current_app
from flask_login import current_user # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import ChildDocument, Document
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService
@ -115,6 +115,9 @@ class IndexingRunner:
for document_segment in document_segments:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
@ -183,7 +186,22 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = document_segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# build index
@ -222,7 +240,7 @@ class IndexingRunner:
doc_language: str = "English",
dataset_id: Optional[str] = None,
indexing_technique: str = "economy",
) -> dict:
) -> IndexingEstimate:
"""
Estimate the indexing for the document.
"""
@ -258,31 +276,38 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts: list[str] = []
preview_texts = []
total_segments = 0
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
for extract_setting in extract_settings:
# extract
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
all_text_docs.extend(text_docs)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
)
# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)
# split to documents
documents = self._split_to_documents_for_estimate(
text_docs=text_docs, splitter=splitter, processing_rule=processing_rule
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
documents = index_processor.transform(
text_docs,
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
tenant_id=current_user.current_tenant_id,
doc_language=doc_language,
preview=True,
)
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if len(preview_texts) < 10:
if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer")
)
preview_texts.append(preview_detail)
else:
preview_detail = PreviewDetail(content=document.page_content)
if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children]
preview_texts.append(preview_detail)
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
@ -299,15 +324,8 @@ class IndexingRunner:
db.session.delete(image_file)
if doc_form and doc_form == "qa_model":
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(
current_user.current_tenant_id, preview_texts[0], doc_language
)
document_qa_list = self.format_split_text(response)
return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts}
return {"total_segments": total_segments, "preview": preview_texts}
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@ -401,31 +419,26 @@ class IndexingRunner:
@staticmethod
def _get_splitter(
processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]
processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule.mode == "custom":
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
rules = json.loads(processing_rule.rules)
segmentation = rules["segmentation"]
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
separator = segmentation["separator"]
if separator:
separator = separator.replace("\\n", "\n")
if segmentation.get("chunk_overlap"):
chunk_overlap = segmentation["chunk_overlap"]
else:
chunk_overlap = 0
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_size=max_tokens,
chunk_overlap=chunk_overlap,
fixed_separator=separator,
separators=["\n\n", "", ". ", " ", ""],
@ -443,142 +456,6 @@ class IndexingRunner:
return character_splitter
def _step_split(
self,
text_docs: list[Document],
splitter: TextSplitter,
dataset: Dataset,
dataset_document: DatasetDocument,
processing_rule: DatasetProcessRule,
) -> list[Document]:
"""
Split the text documents into documents and save them to the document segment.
"""
documents = self._split_to_documents(
text_docs=text_docs,
splitter=splitter,
processing_rule=processing_rule,
tenant_id=dataset.tenant_id,
document_form=dataset_document.doc_form,
document_language=dataset_document.doc_language,
)
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(documents)
# update document status to indexing
cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
},
)
# update segment status to indexing
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
},
)
return documents
def _split_to_documents(
self,
text_docs: list[Document],
splitter: TextSplitter,
processing_rule: DatasetProcessRule,
tenant_id: str,
document_form: str,
document_language: str,
) -> list[Document]:
"""
Split the text documents into nodes.
"""
all_documents: list[Document] = []
all_qa_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.page_content = document_text
# parse document to nodes
documents = splitter.split_documents([text_doc])
split_documents = []
for document_node in documents:
if document_node.page_content.strip():
if document_node.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)
if document_node.page_content:
split_documents.append(document_node)
all_documents.extend(split_documents)
# processing qa document
if document_form == "qa_model":
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i : i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(
target=self.format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": tenant_id,
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": document_language,
},
)
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
return all_qa_documents
return all_documents
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
with flask_app.app_context():
try:
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(
page_content=result["question"], metadata=document_node.metadata.model_copy()
)
if qa_document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception("Failed to format qa document")
all_qa_documents.extend(format_documents)
def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
) -> list[Document]:
@ -624,11 +501,11 @@ class IndexingRunner:
return document_text
@staticmethod
def format_split_text(text):
def format_split_text(text: str) -> list[QAPreviewDetail]:
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE)
return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a]
return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a]
def _load(
self,
@ -654,13 +531,14 @@ class IndexingRunner:
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 10
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
)
create_keyword_thread.start()
if dataset.indexing_technique == "high_quality":
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
@ -680,7 +558,7 @@ class IndexingRunner:
for future in futures:
tokens += future.result()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
create_keyword_thread.join()
indexing_end_at = time.perf_counter()
@ -793,28 +671,6 @@ class IndexingRunner:
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
@staticmethod
def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
"""
Batch add segments index processing
"""
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
# save vector index
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
def _transform(
self,
index_processor: BaseIndexProcessor,
@ -856,7 +712,7 @@ class IndexingRunner:
)
# add document segments
doc_store.add_documents(documents)
doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
# update document status to indexing
cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
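
For orientation, a hedged sketch of how the refactored splitter helper is now parameterized. The IndexingRunner import path and the example rule values are assumptions; inside the runner these values are derived from the dataset's process rule:

from core.indexing_runner import IndexingRunner  # import path assumed

process_rule = {
    "mode": "custom",
    "rules": {"segmentation": {"max_tokens": 500, "chunk_overlap": 50, "separator": "\\n"}},
}
segmentation = process_rule["rules"]["segmentation"]

splitter = IndexingRunner._get_splitter(
    processing_rule_mode=process_rule["mode"],
    max_tokens=segmentation["max_tokens"],
    chunk_overlap=segmentation.get("chunk_overlap", 0),
    separator=segmentation["separator"],  # "\\n" is normalized to a newline inside the helper
    embedding_model_instance=None,  # pass a ModelInstance when indexing in high-quality mode
)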

View File

@ -6,11 +6,14 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
@ -248,3 +251,88 @@ class RetrievalService:
@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')
@staticmethod
def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
records = []
include_segment_ids = []
segment_child_map = {}
for document in documents:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata["doc_id"]
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
)
.first()
)
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
continue
else:
index_node_id = document.metadata["doc_id"]
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
if not segment:
continue
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
records.append(record)
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
record["score"] = segment_child_map[record["segment"].id]["max_score"]
return [RetrievalSegments(**record) for record in records]
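
A hedged usage sketch of the new grouping helper. Here documents is assumed to be a list of core.rag.models.document.Document hits from an earlier vector/keyword/full-text search, each carrying document_id, doc_id, and score in its metadata:

from core.rag.datasource.retrieval_service import RetrievalService

records = RetrievalService.format_retrieval_documents(documents)
for record in records:
    # Each record is a RetrievalSegments: the parent segment plus any matched child chunks.
    print(record.segment.id, record.score)
    for child in record.child_chunks or []:
        print("  child:", child.position, child.score, child.content[:50])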

View File

@ -7,7 +7,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment
class DatasetDocumentStore:
@ -60,7 +60,7 @@ class DatasetDocumentStore:
return output
def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None:
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id)
@ -120,6 +120,23 @@ class DatasetDocumentStore:
segment_document.answer = doc.metadata.pop("answer", "")
db.session.add(segment_document)
db.session.flush()
if save_child:
for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
else:
segment_document.content = doc.page_content
if doc.metadata.get("answer"):
@ -127,6 +144,30 @@ class DatasetDocumentStore:
segment_document.index_node_hash = doc.metadata["doc_hash"]
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).filter(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
ChildChunk.segment_id == segment_document.id,
).delete()
# add new child chunks
for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
db.session.commit()
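
A minimal sketch of the new save_child path, based on the Document/ChildDocument shapes used elsewhere in this diff (dataset, user_id, and document_id are assumed to be previously loaded values):

from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.models.document import ChildDocument, Document

parent = Document(
    page_content="Parent paragraph text.",
    metadata={"doc_id": "parent-node-id", "doc_hash": "parent-node-hash"},
)
parent.children = [
    ChildDocument(
        page_content="First child chunk.",
        metadata={"doc_id": "child-node-id", "doc_hash": "child-node-hash"},
    )
]

doc_store = DatasetDocumentStore(dataset=dataset, user_id=user_id, document_id=document_id)
# With save_child=True, each child becomes a ChildChunk row tied to the parent segment.
doc_store.add_documents(docs=[parent], save_child=True)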

View File

@ -0,0 +1,23 @@
from typing import Optional
from pydantic import BaseModel
from models.dataset import DocumentSegment
class RetrievalChildChunk(BaseModel):
"""Retrieval segments."""
id: str
content: str
score: float
position: int
class RetrievalSegments(BaseModel):
"""Retrieval segments."""
model_config = {"arbitrary_types_allowed": True}
segment: DocumentSegment
child_chunks: Optional[list[RetrievalChildChunk]] = None
score: Optional[float] = None

View File

@ -24,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_markdown_extractor import Unst
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
@ -141,11 +140,7 @@ class ExtractProcessor:
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key)
else:
# txt
extractor = (
UnstructuredTextExtractor(file_path, unstructured_api_url)
if is_automatic
else TextExtractor(file_path, autodetect_encoding=True)
)
extractor = TextExtractor(file_path, autodetect_encoding=True)
else:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)

View File

@ -267,8 +267,10 @@ class WordExtractor(BaseExtractor):
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
para = paragraphs.pop(0)
parsed_paragraph = parse_paragraph(para)
if parsed_paragraph:
if parsed_paragraph.strip():
content.append(parsed_paragraph)
else:
content.append("\n")
elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table
table = tables.pop(0)
content.append(self._table_to_markdown(table, image_map))

View File

@ -1,8 +1,7 @@
from enum import Enum
class IndexType(Enum):
class IndexType(str, Enum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "parent_child_index"
SUMMARY_INDEX = "summary_index"
PARENT_CHILD_INDEX = "hierarchical_model"

View File

@ -27,10 +27,10 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError
@abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
raise NotImplementedError
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
raise NotImplementedError
@abstractmethod
@ -45,26 +45,29 @@ class BaseIndexProcessor(ABC):
) -> list[Document]:
raise NotImplementedError
def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
def _get_splitter(
self,
processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule["mode"] == "custom":
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
rules = processing_rule["rules"]
segmentation = rules["segmentation"]
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
separator = segmentation["separator"]
if separator:
separator = separator.replace("\\n", "\n")
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=segmentation.get("chunk_overlap", 0) or 0,
chunk_size=max_tokens,
chunk_overlap=chunk_overlap,
fixed_separator=separator,
separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance,

View File

@ -3,6 +3,7 @@
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
@ -18,9 +19,11 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX.value:
if self._index_type == IndexType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX.value:
elif self._index_type == IndexType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@ -13,21 +13,34 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
from models.dataset import Dataset, DatasetProcessRule
from services.entities.knowledge_entities.knowledge_entities import Rule
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if process_rule.get("mode") == "automatic":
automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
rules = Rule(**automatic_rule)
else:
rules = Rule(**process_rule.get("rules"))
# Split the text documents into nodes.
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule", {}),
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
all_documents = []
@ -53,15 +66,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if with_keywords:
keywords_list = kwargs.get("keywords_list")
keyword = Keyword(dataset)
keyword.create(documents)
if keywords_list and len(keywords_list) > 0:
keyword.add_texts(documents, keywords_list=keywords_list)
else:
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:

View File

@ -0,0 +1,189 @@
"""Paragraph index processor."""
import uuid
from typing import Optional
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DocumentSegment
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
class ParentChildIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
rules = Rule(**process_rule.get("rules"))
all_documents = []
if rules.parent_mode == ParentMode.PARAGRAPH:
# Split the text documents into nodes.
splitter = self._get_splitter(
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, process_rule)
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
page_content = page_content[1:].strip()
else:
page_content = page_content
if len(page_content) > 0:
document_node.page_content = page_content
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document_node.children = child_nodes
split_documents.append(document_node)
all_documents.extend(split_documents)
elif rules.parent_mode == ParentMode.FULL_DOC:
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document.children = child_nodes
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata["doc_id"] = doc_id
document.metadata["doc_hash"] = hash
all_documents.append(document)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
child_documents = document.children
if child_documents:
formatted_child_documents = [
Document(**child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
vector = Vector(dataset)
if node_ids:
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids)
if delete_child_chunks:
db.session.query(ChildChunk).filter(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
db.session.commit()
else:
vector.delete()
if delete_child_chunks:
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
db.session.commit()
def retrieve(
self,
retrieval_method: str,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def _split_child_nodes(
self,
document_node: Document,
rules: Rule,
process_rule_mode: str,
embedding_model_instance: Optional[ModelInstance],
) -> list[ChildDocument]:
child_splitter = self._get_splitter(
processing_rule_mode=process_rule_mode,
max_tokens=rules.subchunk_segmentation.max_tokens,
chunk_overlap=rules.subchunk_segmentation.chunk_overlap,
separator=rules.subchunk_segmentation.separator,
embedding_model_instance=embedding_model_instance,
)
# parse document to child nodes
child_nodes = []
child_documents = child_splitter.split_documents([document_node])
for child_document_node in child_documents:
if child_document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(child_document_node.page_content)
child_document = ChildDocument(
page_content=child_document_node.page_content, metadata=document_node.metadata
)
child_document.metadata["doc_id"] = doc_id
child_document.metadata["doc_hash"] = hash
child_page_content = child_document.page_content
if child_page_content.startswith(".") or child_page_content.startswith(""):
child_page_content = child_page_content[1:].strip()
if len(child_page_content) > 0:
child_document.page_content = child_page_content
child_nodes.append(child_document)
return child_nodes
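A rough sketch of driving the new processor's transform step; the process_rule dict mirrors the Rule/ProcessRule entities added later in this commit, and the separators, token counts, and variable names below are illustrative assumptions:

    processor = ParentChildIndexProcessor()
    process_rule = {
        "mode": "hierarchical",
        "rules": {
            "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
            "segmentation": {"separator": "\n\n", "max_tokens": 500, "chunk_overlap": 0},
            "parent_mode": "paragraph",
            "subchunk_segmentation": {"separator": "\n", "max_tokens": 200, "chunk_overlap": 0},
        },
    }
    parent_docs = processor.transform(
        text_docs,                      # list[Document] returned by processor.extract(...)
        process_rule=process_rule,
        embedding_model_instance=None,  # a ModelInstance for high_quality datasets
    )
    for parent in parent_docs:
        print(parent.metadata["doc_id"], len(parent.children or []))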

View File

@ -21,18 +21,28 @@ from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import Rule
class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
rules = Rule(**process_rule.get("rules"))
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule") or {},
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
@ -59,6 +69,15 @@ class QAIndexProcessor(BaseIndexProcessor):
document_node.page_content = remove_leading_symbols(page_content)
split_documents.append(document_node)
all_documents.extend(split_documents)
if preview:
self._format_qa_document(
current_app._get_current_object(),
kwargs.get("tenant_id"),
all_documents[0],
all_qa_documents,
kwargs.get("doc_language", "English"),
)
else:
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i : i + 10]
@ -66,7 +85,7 @@ class QAIndexProcessor(BaseIndexProcessor):
document_format_thread = threading.Thread(
target=self._format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"flask_app": current_app._get_current_object(),
"tenant_id": kwargs.get("tenant_id"),
"document_node": doc,
"all_qa_documents": all_qa_documents,
@ -98,12 +117,12 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e))
return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)

View File

@ -5,6 +5,19 @@ from typing import Any, Optional
from pydantic import BaseModel, Field
class ChildDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
vector: Optional[list[float]] = None
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: Optional[dict] = Field(default_factory=dict)
class Document(BaseModel):
"""Class for storing a piece of text and associated metadata."""
@ -19,6 +32,8 @@ class Document(BaseModel):
provider: Optional[str] = "dify"
children: Optional[list[ChildDocument]] = None
class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.

View File

@ -166,43 +166,29 @@ class DatasetRetrieval:
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list = {}
# deal with dify documents
if dify_documents:
for item in dify_documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
records = RetrievalService.format_retrieval_documents(dify_documents)
if records:
for record in records:
segment = record.segment
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=document_score_list.get(segment.index_node_id, None),
score=record.score,
)
)
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=document_score_list.get(segment.index_node_id, None),
score=record.score,
)
)
if show_retrieve_source:
for segment in sorted_segments:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
@ -218,7 +204,7 @@ class DatasetRetrieval:
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": invoke_from.to_source(),
"score": document_score_list.get(segment.index_node_id, 0.0),
"score": record.score or 0.0,
}
if invoke_from.to_source() == "dev":

View File

@ -11,6 +11,7 @@ from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
@ -18,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus
from .entities import KnowledgeRetrievalNodeData
@ -211,29 +212,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list: dict[str, float] = {}
# deal with dify documents
if dify_documents:
document_score_list = {}
for item in dify_documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
records = RetrievalService.format_retrieval_documents(dify_documents)
if records:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
@ -251,7 +235,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"document_data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": document_score_list.get(segment.index_node_id, None),
"score": record.score or 0.0,
"segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count,
"segment_position": segment.position,
@ -270,10 +254,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
reverse=True,
)
position = 1
for item in retrieval_resource_list:
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position
position += 1
return retrieval_resource_list
@classmethod

View File

@ -73,6 +73,7 @@ dataset_detail_fields = {
"embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"tags": fields.List(fields.Nested(tag_fields)),
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
}

View File

@ -34,6 +34,7 @@ document_with_segments_fields = {
"data_source_info": fields.Raw(attribute="data_source_info_dict"),
"data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
"dataset_process_rule_id": fields.String,
"process_rule_dict": fields.Raw(attribute="process_rule_dict"),
"name": fields.String,
"created_from": fields.String,
"created_by": fields.String,

View File

@ -34,8 +34,16 @@ segment_fields = {
"document": fields.Nested(document_fields),
}
child_chunk_fields = {
"id": fields.String,
"content": fields.String,
"position": fields.Integer,
"score": fields.Float,
}
hit_testing_record_fields = {
"segment": fields.Nested(segment_fields),
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"score": fields.Float,
"tsne_position": fields.Raw,
}

View File

@ -2,6 +2,17 @@ from flask_restful import fields # type: ignore
from libs.helper import TimestampField
child_chunk_fields = {
"id": fields.String,
"segment_id": fields.String,
"content": fields.String,
"position": fields.Integer,
"word_count": fields.Integer,
"type": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
}
segment_fields = {
"id": fields.String,
"position": fields.Integer,
@ -20,10 +31,13 @@ segment_fields = {
"status": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
"updated_by": fields.String,
"indexing_at": TimestampField,
"completed_at": TimestampField,
"error": fields.String,
"stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
}
segment_list_response = {

View File

@ -0,0 +1,55 @@
"""parent-child-index
Revision ID: e19037032219
Revises: 01d6889832f7
Create Date: 2024-11-22 07:01:17.550037
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'e19037032219'
down_revision = 'd7999dfa4aae'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('child_chunks',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('segment_id', models.types.StringUUID(), nullable=False),
sa.Column('position', sa.Integer(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('word_count', sa.Integer(), nullable=False),
sa.Column('index_node_id', sa.String(length=255), nullable=True),
sa.Column('index_node_hash', sa.String(length=255), nullable=True),
sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('indexing_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.Column('error', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id', name='child_chunk_pkey')
)
with op.batch_alter_table('child_chunks', schema=None) as batch_op:
batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('child_chunks', schema=None) as batch_op:
batch_op.drop_index('child_chunk_dataset_id_idx')
op.drop_table('child_chunks')
# ### end Alembic commands ###

View File

@ -0,0 +1,47 @@
"""add_auto_disabled_dataset_logs
Revision ID: 923752d42eb6
Revises: e19037032219
Create Date: 2024-12-25 11:37:55.467101
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '923752d42eb6'
down_revision = 'e19037032219'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_auto_disable_logs',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey')
)
with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False)
batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False)
batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
batch_op.drop_index('dataset_auto_disable_log_tenant_idx')
batch_op.drop_index('dataset_auto_disable_log_dataset_idx')
batch_op.drop_index('dataset_auto_disable_log_created_atx')
op.drop_table('dataset_auto_disable_logs')
# ### end Alembic commands ###

View File

@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import JSONB
from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account
from .engine import db
@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
MODES = ["automatic", "custom"]
MODES = ["automatic", "custom", "hierarchical"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
AUTOMATIC_RULES: dict[str, Any] = {
"pre_processing_rules": [
@ -231,8 +232,6 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
"dataset_id": self.dataset_id,
"mode": self.mode,
"rules": self.rules_dict,
"created_by": self.created_by,
"created_at": self.created_at,
}
@property
@ -396,6 +395,12 @@ class Document(db.Model): # type: ignore[name-defined]
.scalar()
)
@property
def process_rule_dict(self):
if self.dataset_process_rule_id:
return self.dataset_process_rule.to_dict()
return None
def to_dict(self):
return {
"id": self.id,
@ -560,6 +565,24 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
.first()
)
@property
def child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
return []
def get_sign_content(self):
signed_urls = []
text = self.content
@ -605,6 +628,47 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
return text
class ChildChunk(db.Model):
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
)
# initial fields
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
segment_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False)
content = db.Column(db.Text, nullable=False)
word_count = db.Column(db.Integer, nullable=False)
# indexing fields
index_node_id = db.Column(db.String(255), nullable=True)
index_node_hash = db.Column(db.String(255), nullable=True)
type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
indexing_at = db.Column(db.DateTime, nullable=True)
completed_at = db.Column(db.DateTime, nullable=True)
error = db.Column(db.Text, nullable=True)
@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()
@property
def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first()
@property
def segment(self):
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(db.Model): # type: ignore[name-defined]
__tablename__ = "app_dataset_joins"
__table_args__ = (
@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(db.Model):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
db.Index("dataset_auto_disable_log_created_atx", "created_at"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
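A short read-side sketch for the two new models, assuming an application context and placeholder segment_id / tenant_id values; segment.child_chunks only yields rows when the owning document was processed in hierarchical mode.

    from extensions.ext_database import db
    from models.dataset import DatasetAutoDisableLog, DocumentSegment

    segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
    if segment:
        for child in segment.child_chunks:  # ordered by position, [] for non-hierarchical documents
            print(child.position, child.content[:50])

    pending = (
        db.session.query(DatasetAutoDisableLog)
        .filter(DatasetAutoDisableLog.tenant_id == tenant_id, DatasetAutoDisableLog.notified == False)
        .all()
    )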

View File

@ -10,7 +10,7 @@ from configs import dify_config
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetQuery, Document
from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document
from services.feature_service import FeatureService
@ -75,6 +75,23 @@ def clean_unused_datasets_task():
)
if not dataset_query or len(dataset_query) == 0:
try:
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
@ -151,6 +168,23 @@ def clean_unused_datasets_task():
else:
plan = plan_cache.decode()
if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)

View File

@ -0,0 +1,66 @@
import logging
import time
import click
from celery import shared_task
from flask import render_template
from extensions.ext_mail import mail
from models.account import Account, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetAutoDisableLog
@shared_task(queue="mail")
def send_document_clean_notify_task():
"""
Async Send document clean notify mail
Usage: send_document_clean_notify_task.delay()
"""
if not mail.is_inited():
return
logging.info(click.style("Start send document clean notify mail", fg="green"))
start_at = time.perf_counter()
# send document clean notify mail
try:
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = {}
for dataset_auto_disable_log in dataset_auto_disable_logs:
dataset_auto_disable_logs_map.setdefault(dataset_auto_disable_log.tenant_id, []).append(dataset_auto_disable_log)
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
knowledge_details = []
tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
if not tenant:
continue
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
account = Account.query.filter(Account.id == current_owner_join.account_id).first()
if not account:
continue
dataset_auto_dataset_map: dict[str, list[str]] = {}
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_dataset_map.setdefault(dataset_auto_disable_log.dataset_id, []).append(
dataset_auto_disable_log.document_id
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
if dataset:
document_count = len(document_ids)
knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>")
# assumption: the mail template accepts a `knowledge_details` context variable with the per-knowledge-base summary
html_content = render_template(
"clean_document_job_mail_template-US.html",
knowledge_details=knowledge_details,
)
mail.send(to=account.email, subject="Some documents in your knowledge base were automatically disabled", html=html_content)
end_at = time.perf_counter()
logging.info(
click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
)
except Exception:
logging.exception("Send invite member mail to {} failed".format(to))

File diff suppressed because it is too large

View File

@ -1,4 +1,5 @@
from typing import Optional
from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel
@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel):
answer: Optional[str] = None
keywords: Optional[list[str]] = None
enabled: Optional[bool] = None
class ParentMode(str, Enum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
class NotionIcon(BaseModel):
type: str
url: Optional[str] = None
emoji: Optional[str] = None
class NotionPage(BaseModel):
page_id: str
page_name: str
page_icon: Optional[NotionIcon] = None
type: str
class NotionInfo(BaseModel):
workspace_id: str
pages: list[NotionPage]
class WebsiteInfo(BaseModel):
provider: str
job_id: str
urls: list[str]
only_main_content: bool = True
class FileInfo(BaseModel):
file_ids: list[str]
class InfoList(BaseModel):
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
notion_info_list: Optional[list[NotionInfo]] = None
file_info_list: Optional[FileInfo] = None
website_info_list: Optional[WebsiteInfo] = None
class DataSource(BaseModel):
info_list: InfoList
class PreProcessingRule(BaseModel):
id: str
enabled: bool
class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0
class Rule(BaseModel):
pre_processing_rules: Optional[list[PreProcessingRule]] = None
segmentation: Optional[Segmentation] = None
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
subchunk_segmentation: Optional[Segmentation] = None
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Optional[Rule] = None
class RerankingModel(BaseModel):
reranking_provider_name: Optional[str] = None
reranking_model_name: Optional[str] = None
class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search"]
reranking_enable: bool
reranking_model: Optional[RerankingModel] = None
top_k: int
score_threshold_enabled: bool
score_threshold: Optional[float] = None
class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: Optional[DataSource] = None
process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: Optional[str] = None
embedding_model_provider: Optional[str] = None
name: Optional[str] = None
class SegmentUpdateArgs(BaseModel):
content: Optional[str] = None
answer: Optional[str] = None
keywords: Optional[list[str]] = None
regenerate_child_chunks: bool = False
enabled: Optional[bool] = None
class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
content: str
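A hedged example of the hierarchical knowledge configuration these entities model; the numbers and separators are placeholders:

    config = KnowledgeConfig(
        indexing_technique="high_quality",
        doc_form="hierarchical_model",
        process_rule=ProcessRule(
            mode="hierarchical",
            rules=Rule(
                pre_processing_rules=[PreProcessingRule(id="remove_extra_spaces", enabled=True)],
                segmentation=Segmentation(separator="\n\n", max_tokens=500),
                parent_mode="paragraph",
                subchunk_segmentation=Segmentation(separator="\n", max_tokens=200),
            ),
        ),
    )
    print(config.model_dump(exclude_none=True))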

View File

@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class ChildChunkIndexingError(BaseServiceError):
description = "{message}"
class ChildChunkDeleteIndexError(BaseServiceError):
description = "{message}"

View File

@ -7,7 +7,7 @@ from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Dataset, DatasetQuery
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@ -69,7 +69,7 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()
return dict(cls.compact_retrieve_response(dataset, query, all_documents))
return cls.compact_retrieve_response(query, all_documents)
@classmethod
def external_retrieve(
@ -106,41 +106,14 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
records = []
for document in documents:
if document.metadata is None:
continue
index_node_id = document.metadata["doc_id"]
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
if not segment:
continue
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
records.append(record)
def compact_retrieve_response(cls, query: str, documents: list[Document]):
records = RetrievalService.format_retrieval_documents(documents)
return {
"query": {
"content": query,
},
"records": records,
"records": [record.model_dump() for record in records],
}
@classmethod

View File

@ -1,18 +1,55 @@
from typing import Optional
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from models.dataset import Dataset, DocumentSegment
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
class VectorService:
@classmethod
def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents = []
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
document = DatasetDocument.query.filter_by(id=segment.document_id).first()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
model_manager = ModelManager()
if dataset.embedding_model_provider:
embedding_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
raise ValueError("The knowledge base index technique is not high quality!")
cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
else:
document = Document(
page_content=segment.content,
metadata={
@ -23,18 +60,9 @@ class VectorService:
},
)
documents.append(document)
if dataset.indexing_technique == "high_quality":
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts(documents, duplicate_check=True)
# save keyword index
keyword = Keyword(dataset)
if keywords_list and len(keywords_list) > 0:
keyword.add_texts(documents, keywords_list=keywords_list)
else:
keyword.add_texts(documents)
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
@classmethod
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
@ -65,3 +93,123 @@ class VectorService:
keyword.add_texts([document], keywords_list=[keywords])
else:
keyword.add_texts([document])
@classmethod
def generate_child_chunks(
cls,
segment: DocumentSegment,
dataset_document: DatasetDocument,
dataset: Dataset,
embedding_model_instance: ModelInstance,
processing_rule: DatasetProcessRule,
regenerate: bool = False,
):
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
if regenerate:
# delete child chunks
index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)
# generate child chunks
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
# use full doc mode to generate segment's child chunk
processing_rule_dict = processing_rule.to_dict()
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
documents = index_processor.transform(
[document],
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule_dict,
tenant_id=dataset.tenant_id,
doc_language=dataset_document.doc_language,
)
# save child chunks
if len(documents) > 0 and len(documents[0].children) > 0:
index_processor.load(dataset, documents)
for position, child_chunk in enumerate(documents[0].children, start=1):
child_segment = ChildChunk(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=dataset_document.id,
segment_id=segment.id,
position=position,
index_node_id=child_chunk.metadata["doc_id"],
index_node_hash=child_chunk.metadata["doc_hash"],
content=child_chunk.page_content,
word_count=len(child_chunk.page_content),
type="automatic",
created_by=dataset_document.created_by,
)
db.session.add(child_segment)
db.session.commit()
@classmethod
def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
child_document = Document(
page_content=child_segment.content,
metadata={
"doc_id": child_segment.index_node_id,
"doc_hash": child_segment.index_node_hash,
"document_id": child_segment.document_id,
"dataset_id": child_segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)
@classmethod
def update_child_chunk_vector(
cls,
new_child_chunks: list[ChildChunk],
update_child_chunks: list[ChildChunk],
delete_child_chunks: list[ChildChunk],
dataset: Dataset,
):
documents = []
delete_node_ids = []
for new_child_chunk in new_child_chunks:
new_child_document = Document(
page_content=new_child_chunk.content,
metadata={
"doc_id": new_child_chunk.index_node_id,
"doc_hash": new_child_chunk.index_node_hash,
"document_id": new_child_chunk.document_id,
"dataset_id": new_child_chunk.dataset_id,
},
)
documents.append(new_child_document)
for update_child_chunk in update_child_chunks:
child_document = Document(
page_content=update_child_chunk.content,
metadata={
"doc_id": update_child_chunk.index_node_id,
"doc_hash": update_child_chunk.index_node_hash,
"document_id": update_child_chunk.document_id,
"dataset_id": update_child_chunk.dataset_id,
},
)
documents.append(child_document)
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == "high_quality":
# update vector index
vector = Vector(dataset=dataset)
if delete_node_ids:
vector.delete_by_ids(delete_node_ids)
if documents:
vector.add_texts(documents, duplicate_check=True)
@classmethod
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
vector.delete_by_ids([child_chunk.index_node_id])
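A call-site sketch of the reworked vector service; segment, dataset, and edited_child_chunk are assumed to already exist, and doc_form decides whether child chunks are generated:

    # hierarchical datasets: child chunks are generated and indexed per parent segment
    VectorService.create_segments_vector(
        keywords_list=None,
        segments=[segment],
        dataset=dataset,
        doc_form="hierarchical_model",
    )

    # editing a single child chunk afterwards only touches the vector index
    VectorService.update_child_chunk_vector(
        new_child_chunks=[],
        update_child_chunks=[edited_child_chunk],
        delete_child_chunks=[],
        dataset=dataset,
    )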

View File

@ -6,12 +6,13 @@ import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.dataset import DocumentSegment
@shared_task(queue="dataset")
@ -53,7 +54,22 @@ def add_document_to_index_task(dataset_document_id: str):
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
dataset = dataset_document.dataset
@ -65,6 +81,12 @@ def add_document_to_index_task(dataset_document_id: str):
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).filter(
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style(

View File

@ -0,0 +1,75 @@
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DocumentSegment
from models.model import UploadFile
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
"""
Clean document when document deleted.
:param document_ids: document ids
:param dataset_id: dataset id
:param doc_form: doc_form
:param file_ids: file ids
Usage: clean_document_task.delay(document_id, dataset_id)
"""
logging.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_id: {}".format(upload_file_id)
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.commit()
if file_ids:
files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id))
db.session.delete(file)
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style(
"Cleaned documents when documents deleted latency: {}".format(end_at - start_at),
fg="green",
)
)
except Exception:
logging.exception("Cleaned documents when documents deleted failed")

View File

@ -7,13 +7,13 @@ import click
from celery import shared_task # type: ignore
from sqlalchemy import func
from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Dataset, Document, DocumentSegment
from services.vector_service import VectorService
@shared_task(queue="dataset")
@ -96,8 +96,7 @@ def batch_create_segment_to_index_task(
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
# add index to db
indexing_runner = IndexingRunner()
indexing_runner.batch_add_segments(document_segments, dataset)
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()

View File

@ -62,7 +62,7 @@ def clean_dataset_task(
if doc_form is None:
raise ValueError("Index type must be specified.")
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None)
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
for document in documents:
db.session.delete(document)

View File

@ -38,7 +38,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)

View File

@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)

View File

@ -4,8 +4,9 @@ import time
import click
from celery import shared_task # type: ignore
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
db.session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False)
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)

View File

@ -6,48 +6,38 @@ from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document
@shared_task(queue="dataset")
def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str):
def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
"""
Async Remove segment from index
:param segment_id:
:param index_node_id:
:param index_node_ids:
:param dataset_id:
:param document_id:
Usage: delete_segment_from_index_task.delay(segment_id)
Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id)
"""
logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green"))
logging.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = "segment_{}_delete_indexing".format(segment_id)
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan"))
return
dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
if not dataset_document:
logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan"))
return
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [index_node_id])
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter()
logging.info(
click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green")
)
logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green"))
except Exception:
logging.exception("delete segment from index failed")
finally:
redis_client.delete(indexing_cache_key)
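The task now operates on a batch of index node ids rather than a single segment; a hedged call-site sketch with placeholder names:

    index_node_ids = [segment.index_node_id for segment in segments_to_delete]
    delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)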

View File

@ -0,0 +1,76 @@
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str):
"""
Async disable segments from index
:param segment_ids:
:param dataset_id:
:param document_id:
Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
if not segments:
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green"))
except Exception:
# update segment error msg
db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)
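A hedged call-site sketch for the new bulk-disable task (ids are placeholders); child chunk rows are kept because the segments may be re-enabled later:

    disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)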

View File

@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)

View File

@ -47,7 +47,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)

View File

@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
return
@ -73,14 +73,14 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.utcnow()
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
documents.append(document)
db.session.add(document)
db.session.commit()
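The utcnow() replacements repeated across these tasks are equivalent: both forms produce a naive datetime in UTC, and datetime.utcnow() is deprecated as of Python 3.12. A small helper (hypothetical, not part of the commit) would capture the pattern:

import datetime

def naive_utc_now() -> datetime.datetime:
    """Current UTC time as a naive datetime, matching the existing timestamp columns."""
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)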

View File

@ -6,8 +6,9 @@ import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@ -61,6 +62,22 @@ def enable_segment_to_index_task(segment_id: str):
return
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
# save vector index
index_processor.load(dataset, [document])

View File

@ -0,0 +1,108 @@
import datetime
import logging
import time
import click
from celery import shared_task
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str):
"""
Async enable segments to index
:param segment_ids:
Usage: enable_segments_to_index_task.delay(segment_ids)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
if not segments:
return
try:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents)
end_at = time.perf_counter()
logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green"))
except Exception as e:
logging.exception("enable segments to index failed")
# record the error and disable the segments so the failure is visible
db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"error": str(e),
"status": "error",
"disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
"enabled": False,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)
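The task clears a per-segment "segment_{id}_indexing" cache key in its finally block, which implies the caller sets that key before dispatch. A hedged sketch of the dispatch side; the controller code is not part of this hunk, the import path assumes the tasks/ layout used elsewhere, and the TTL is an assumption:

from extensions.ext_redis import redis_client
from tasks.enable_segments_to_index_task import enable_segments_to_index_task

def dispatch_enable_segments(segment_ids: list, dataset_id: str, document_id: str) -> None:
    for segment_id in segment_ids:
        # mark each segment as "indexing in progress"; the task deletes this key when done
        redis_client.setex("segment_{}_indexing".format(segment_id), 600, 1)
    enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)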

View File

@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str):
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
except Exception:
logging.exception(f"clean dataset {dataset.id} from index failed")

View File

@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
@ -69,14 +69,14 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.utcnow()
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))

View File

@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
@ -65,14 +65,14 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.utcnow()
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))

View File

@ -0,0 +1,98 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Documents Disabled Notification</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 0;
background-color: #f5f5f5;
}
.email-container {
max-width: 600px;
margin: 20px auto;
background: #ffffff;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
overflow: hidden;
}
.header {
background-color: #eef2fa;
padding: 20px;
text-align: center;
}
.header img {
height: 40px;
}
.content {
padding: 20px;
line-height: 1.6;
color: #333;
}
.content h1 {
font-size: 24px;
color: #222;
}
.content p {
margin: 10px 0;
}
.content ul {
padding-left: 20px;
}
.content ul li {
margin-bottom: 10px;
}
.cta-button {
display: block;
margin: 20px auto;
padding: 10px 20px;
background-color: #4e89f9;
color: #ffffff;
text-align: center;
text-decoration: none;
border-radius: 5px;
width: fit-content;
}
.footer {
text-align: center;
padding: 10px;
font-size: 12px;
color: #777;
background-color: #f9f9f9;
}
</style>
</head>
<body>
<div class="email-container">
<!-- Header -->
<div class="header">
<img src="https://via.placeholder.com/150x40?text=Dify" alt="Dify Logo">
</div>
<!-- Content -->
<div class="content">
<h1>Some Documents in Your Knowledge Base Have Been Disabled</h1>
<p>Dear {{userName}},</p>
<p>
We're sorry for the inconvenience. To ensure optimal performance, documents
that haven't been updated or accessed in the past 7 days have been disabled in
your knowledge bases:
</p>
<ul>
{{knowledge_details}}
</ul>
<p>You can re-enable them anytime.</p>
<a href="{{url}}" class="cta-button">Re-enable in Dify</a>
</div>
<!-- Footer -->
<div class="footer">
Sincerely,<br>
The Dify Team
</div>
</div>
</body>
</html>
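The template expects three placeholders: {{userName}}, {{knowledge_details}} (pre-built <li> items), and {{url}}. A hedged Jinja2 rendering sketch; the template filename and all values below are illustrative assumptions, not taken from this commit:

from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("templates"))  # autoescape off: knowledge_details is already HTML
template = env.get_template("documents_disabled_notification.html")  # hypothetical filename
html = template.render(
    userName="alice@example.com",
    knowledge_details="<li>Knowledge base A: 3 documents disabled</li>",
    url="https://cloud.dify.ai/datasets",  # placeholder target for the re-enable button
)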