diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 0d400000db..46a5330bdb 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -1,13 +1,9 @@
 import concurrent.futures
-import logging
-import time
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
 from flask import Flask, current_app
-from sqlalchemy import and_, or_
 from sqlalchemy.orm import load_only
-from sqlalchemy.sql.expression import false
 
 from configs import dify_config
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
@@ -182,7 +178,6 @@ class RetrievalService:
                 if not dataset:
                     raise ValueError("dataset not found")
 
-                start = time.time()
                 vector = Vector(dataset=dataset)
                 documents = vector.search_by_vector(
                     query,
@@ -192,7 +187,6 @@ class RetrievalService:
                     filter={"group_id": [dataset.id]},
                     document_ids_filter=document_ids_filter,
                 )
-                logging.debug(f"embedding_search ends at {time.time() - start:.2f} seconds")
 
                 if documents:
                     if (
@@ -276,8 +270,7 @@ class RetrievalService:
             return []
 
         try:
-            start_time = time.time()
-            # Collect document IDs with existence check
+            # Collect document IDs
             document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
             if not document_ids:
                 return []
@@ -295,138 +288,110 @@ class RetrievalService:
             include_segment_ids = set()
             segment_child_map = {}
 
-            # Precompute doc_forms to avoid redundant checks
-            doc_forms = {}
-            for doc in documents:
-                document_id = doc.metadata.get("document_id")
-                dataset_doc = dataset_documents.get(document_id)
-                if dataset_doc:
-                    doc_forms[document_id] = dataset_doc.doc_form
-
-            # Batch collect index node IDs with type safety
-            child_index_node_ids = []
-            index_node_ids = []
-            for doc in documents:
-                document_id = doc.metadata.get("document_id")
-                if doc_forms.get(document_id) == IndexType.PARENT_CHILD_INDEX:
-                    child_index_node_ids.append(doc.metadata.get("doc_id"))
-                else:
-                    index_node_ids.append(doc.metadata.get("doc_id"))
-
-            # Batch query ChildChunk
-            child_chunks = db.session.query(ChildChunk).filter(ChildChunk.index_node_id.in_(child_index_node_ids)).all()
-            child_chunk_map = {chunk.index_node_id: chunk for chunk in child_chunks}
-
-            segment_ids_from_child = [chunk.segment_id for chunk in child_chunks]
-            segment_conditions = []
-
-            if index_node_ids:
-                segment_conditions.append(DocumentSegment.index_node_id.in_(index_node_ids))
-
-            if segment_ids_from_child:
-                segment_conditions.append(DocumentSegment.id.in_(segment_ids_from_child))
-
-            if segment_conditions:
-                filter_expr = or_(*segment_conditions)
-            else:
-                filter_expr = false()
-
-            segment_map = {
-                segment.id: segment
-                for segment in db.session.query(DocumentSegment)
-                .filter(
-                    and_(
-                        filter_expr,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                    )
-                )
-                .options(
-                    load_only(
-                        DocumentSegment.id,
-                        DocumentSegment.content,
-                        DocumentSegment.answer,
-                    )
-                )
-                .all()
-            }
-
+            # Process documents
             for document in documents:
                 document_id = document.metadata.get("document_id")
-                dataset_document = dataset_documents.get(document_id)
+                if document_id not in dataset_documents:
+                    continue
+
+                dataset_document = dataset_documents[document_id]
                 if not dataset_document:
                     continue
 
-                doc_form = doc_forms.get(document_id)
-                if doc_form == IndexType.PARENT_CHILD_INDEX:
-                    # Handle parent-child documents using preloaded data
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # Handle parent-child documents
                     child_index_node_id = document.metadata.get("doc_id")
-                    if not child_index_node_id:
-                        continue
 
-                    child_chunk = child_chunk_map.get(child_index_node_id)
+                    child_chunk = (
+                        db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
+                    )
+
                     if not child_chunk:
                         continue
 
-                    segment = segment_map.get(child_chunk.segment_id)
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.id == child_chunk.segment_id,
+                        )
+                        .options(
+                            load_only(
+                                DocumentSegment.id,
+                                DocumentSegment.content,
+                                DocumentSegment.answer,
+                            )
+                        )
+                        .first()
+                    )
+
                     if not segment:
                         continue
 
                     if segment.id not in include_segment_ids:
                         include_segment_ids.add(segment.id)
-                        map_detail = {"max_score": document.metadata.get("score", 0.0), "child_chunks": []}
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        map_detail = {
+                            "max_score": document.metadata.get("score", 0.0),
+                            "child_chunks": [child_chunk_detail],
+                        }
                         segment_child_map[segment.id] = map_detail
-                        records.append({"segment": segment})
-
-                    # Append child chunk details
-                    child_chunk_detail = {
-                        "id": child_chunk.id,
-                        "content": child_chunk.content,
-                        "position": child_chunk.position,
-                        "score": document.metadata.get("score", 0.0),
-                    }
-                    segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                    segment_child_map[segment.id]["max_score"] = max(
-                        segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                    )
-
+                        record = {
+                            "segment": segment,
+                        }
+                        records.append(record)
+                    else:
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                        segment_child_map[segment.id]["max_score"] = max(
+                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                        )
                 else:
                     # Handle normal documents
                     index_node_id = document.metadata.get("doc_id")
                     if not index_node_id:
                         continue
 
-                    segment = next(
-                        (
-                            s
-                            for s in segment_map.values()
-                            if s.index_node_id == index_node_id and s.dataset_id == dataset_document.dataset_id
-                        ),
-                        None,
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.index_node_id == index_node_id,
+                        )
+                        .first()
                     )
 
                     if not segment:
                         continue
 
-                    if segment.id not in include_segment_ids:
-                        include_segment_ids.add(segment.id)
-                        records.append(
-                            {
-                                "segment": segment,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                        )
+                    include_segment_ids.add(segment.id)
+                    record = {
+                        "segment": segment,
+                        "score": document.metadata.get("score"),  # type: ignore
+                    }
+                    records.append(record)
 
-            # Merge child chunks information
+            # Add child chunks information to records
             for record in records:
-                segment_id = record["segment"].id
-                if segment_id in segment_child_map:
-                    record["child_chunks"] = segment_child_map[segment_id]["child_chunks"]
-                    record["score"] = segment_child_map[segment_id]["max_score"]
+                if record["segment"].id in segment_child_map:
+                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
+                    record["score"] = segment_child_map[record["segment"].id]["max_score"]
 
-            logging.debug(f"Formatting retrieval documents took {time.time() - start_time:.2f} seconds")
             return [RetrievalSegments(**record) for record in records]
         except Exception as e:
-            # Only rollback if there were write operations
             db.session.rollback()
             raise e