diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8f8aaa93d6..64c734f626 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,9 +1,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent +from core.rag.index_processor.constant.index_type import IndexType from core.rag.models.document import Document from extensions.ext_database import db -from models.dataset import DatasetQuery, DocumentSegment +from models.dataset import ChildChunk, DatasetQuery, DocumentSegment +from models.dataset import Document as DatasetDocument from models.model import DatasetRetrieverResource @@ -41,15 +43,29 @@ class DatasetIndexToolCallbackHandler: """Handle tool end.""" for document in documents: if document.metadata is not None: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + dataset_document = DatasetDocument.query.filter( + DatasetDocument.id == document.metadata["document_id"] + ).first() + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunk = ChildChunk.query.filter( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ).first() + if child_chunk: + segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False + ) + else: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index ac868a2250..c5ac63e853 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -21,6 +21,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.context_entities import DocumentContext +from core.rag.index_processor.constant.index_type import IndexType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -28,7 +29,7 @@ from core.rag.retrieval.router.multi_dataset_function_call_router import Functio from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db -from models.dataset import Dataset, DatasetQuery, DocumentSegment +from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService @@ -429,16 +430,31 @@ class DatasetRetrieval: dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: if document.metadata is not None: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + dataset_document = DatasetDocument.query.filter( + DatasetDocument.id == document.metadata["document_id"] + ).first() + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunk = ChildChunk.query.filter( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ).first() + if child_chunk: + segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False + ) + db.session.commit() + else: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + # if 'dataset_id' in document.metadata: + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit()