diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index fea4d0edf7..c4a1e9f059 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -1,4 +1,6 @@
 import concurrent.futures
+import logging
+import time
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
@@ -46,7 +48,7 @@ class RetrievalService:
         if not query:
             return []
         dataset = cls._get_dataset(dataset_id)
-        if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
+        if not dataset:
             return []
 
         all_documents: list[Document] = []
@@ -178,6 +180,7 @@ class RetrievalService:
         if not dataset:
             raise ValueError("dataset not found")
 
+        start = time.time()
         vector = Vector(dataset=dataset)
         documents = vector.search_by_vector(
             query,
@@ -187,6 +190,7 @@ class RetrievalService:
             filter={"group_id": [dataset.id]},
             document_ids_filter=document_ids_filter,
         )
+        logging.debug(f"embedding_search took {time.time() - start:.2f} seconds")
 
         if documents:
             if (
@@ -270,7 +274,8 @@ class RetrievalService:
             return []
 
         try:
-            # Collect document IDs
+            start_time = time.time()
+            # Collect document IDs with existence check
             document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
             if not document_ids:
                 return []
@@ -288,43 +293,102 @@ class RetrievalService:
             include_segment_ids = set()
             segment_child_map = {}
 
-            # Process documents
+            # Precompute doc_forms to avoid redundant checks
+            doc_forms = {}
+            for doc in documents:
+                document_id = doc.metadata.get("document_id")
+                dataset_doc = dataset_documents.get(document_id)
+                if dataset_doc:
+                    doc_forms[document_id] = dataset_doc.doc_form
+
+            # Batch collect index node IDs, split by doc form
+            child_index_node_ids = []
+            index_node_ids = []
+            for doc in documents:
+                document_id = doc.metadata.get("document_id")
+                if doc_forms.get(document_id) == IndexType.PARENT_CHILD_INDEX:
+                    child_index_node_ids.append(doc.metadata.get("doc_id"))
+                else:
+                    index_node_ids.append(doc.metadata.get("doc_id"))
+
+            # Batch query ChildChunk
+            child_chunks = db.session.query(ChildChunk).filter(ChildChunk.index_node_id.in_(child_index_node_ids)).all()
+            child_chunk_map = {chunk.index_node_id: chunk for chunk in child_chunks}
+
+            # Batch query DocumentSegment with unified conditions
+            segment_map = {
+                segment.id: segment
+                for segment in db.session.query(DocumentSegment)
+                .filter(
+                    (
+                        DocumentSegment.index_node_id.in_(index_node_ids)
+                        | DocumentSegment.id.in_([chunk.segment_id for chunk in child_chunks])
+                    ),
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                )
+                .options(
+                    load_only(
+                        DocumentSegment.id,
+                        DocumentSegment.content,
+                        DocumentSegment.answer,
+                    )
+                )
+                .all()
+            }
+
             for document in documents:
                 document_id = document.metadata.get("document_id")
-                if document_id not in dataset_documents:
-                    continue
-
-                dataset_document = dataset_documents[document_id]
+                dataset_document = dataset_documents.get(document_id)
                 if not dataset_document:
                     continue
 
-                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                    # Handle parent-child documents
+                doc_form = doc_forms.get(document_id)
+                if doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # Handle parent-child documents using preloaded data
                     child_index_node_id = document.metadata.get("doc_id")
+                    if not child_index_node_id:
+                        continue
 
-                    child_chunk = (
-                        db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
-                    )
-
+                    child_chunk = child_chunk_map.get(child_index_node_id)
                     if not child_chunk:
                         continue
 
-                    segment = (
-                        db.session.query(DocumentSegment)
-                        .filter(
-                            DocumentSegment.dataset_id == dataset_document.dataset_id,
-                            DocumentSegment.enabled == True,
-                            DocumentSegment.status == "completed",
-                            DocumentSegment.id == child_chunk.segment_id,
-                        )
-                        .options(
-                            load_only(
-                                DocumentSegment.id,
-                                DocumentSegment.content,
-                                DocumentSegment.answer,
-                            )
-                        )
-                        .first()
+                    segment = segment_map.get(child_chunk.segment_id)
+                    if not segment:
+                        continue
+
+                    if segment.id not in include_segment_ids:
+                        include_segment_ids.add(segment.id)
+                        map_detail = {"max_score": document.metadata.get("score", 0.0), "child_chunks": []}
+                        segment_child_map[segment.id] = map_detail
+                        records.append({"segment": segment})
+
+                    # Append child chunk details
+                    child_chunk_detail = {
+                        "id": child_chunk.id,
+                        "content": child_chunk.content,
+                        "position": child_chunk.position,
+                        "score": document.metadata.get("score", 0.0),
+                    }
+                    segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                    segment_child_map[segment.id]["max_score"] = max(
+                        segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                    )
+
+                else:
+                    # Handle normal documents
+                    index_node_id = document.metadata.get("doc_id")
+                    if not index_node_id:
+                        continue
+
+                    segment = next(
+                        (
+                            s
+                            for s in segment_map.values()
+                            if s.index_node_id == index_node_id and s.dataset_id == dataset_document.dataset_id
+                        ),
+                        None,
                     )
 
                     if not segment:
@@ -332,66 +396,23 @@ class RetrievalService:
 
                     if segment.id not in include_segment_ids:
                         include_segment_ids.add(segment.id)
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        map_detail = {
-                            "max_score": document.metadata.get("score", 0.0),
-                            "child_chunks": [child_chunk_detail],
-                        }
-                        segment_child_map[segment.id] = map_detail
-                        record = {
-                            "segment": segment,
-                        }
-                        records.append(record)
-                    else:
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                        segment_child_map[segment.id]["max_score"] = max(
-                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                        records.append(
+                            {
+                                "segment": segment,
+                                "score": document.metadata.get("score", 0.0),
+                            }
                         )
-                else:
-                    # Handle normal documents
-                    index_node_id = document.metadata.get("doc_id")
-                    if not index_node_id:
-                        continue
-                    segment = (
-                        db.session.query(DocumentSegment)
-                        .filter(
-                            DocumentSegment.dataset_id == dataset_document.dataset_id,
-                            DocumentSegment.enabled == True,
-                            DocumentSegment.status == "completed",
-                            DocumentSegment.index_node_id == index_node_id,
-                        )
-                        .first()
-                    )
-
-                    if not segment:
-                        continue
-
-                    include_segment_ids.add(segment.id)
-                    record = {
-                        "segment": segment,
-                        "score": document.metadata.get("score"),  # type: ignore
-                    }
-                    records.append(record)
-
-            # Add child chunks information to records
+            # Merge child chunks information
             for record in records:
-                if record["segment"].id in segment_child_map:
-                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
-                    record["score"] = segment_child_map[record["segment"].id]["max_score"]
+                segment_id = record["segment"].id
+                if segment_id in segment_child_map:
+                    record["child_chunks"] = segment_child_map[segment_id]["child_chunks"]
+                    record["score"] = segment_child_map[segment_id]["max_score"]
 
+            logging.debug(f"Formatting retrieval documents took {time.time() - start_time:.2f} seconds")
             return [RetrievalSegments(**record) for record in records]
         except Exception as e:
+            # Roll back so a failed query doesn't leave the session in an aborted state
            db.session.rollback()
            raise e
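The core of this change is an N+1 elimination: instead of issuing one `ChildChunk` and one `DocumentSegment` query per retrieved document, the loop now reads from two maps built by a pair of batched `IN (...)` queries. A minimal, self-contained sketch of the pattern (plain dataclasses stand in for the ORM models; all names here are illustrative, not Dify's API):

```python
from dataclasses import dataclass


@dataclass
class FakeChunk:  # hypothetical stand-in for ChildChunk
    index_node_id: str
    segment_id: str


@dataclass
class FakeSegment:  # hypothetical stand-in for DocumentSegment
    id: str
    content: str


def format_results(doc_node_ids: list[str], chunks: list[FakeChunk], segments: list[FakeSegment]):
    # Build lookup maps once -- this stands in for the two batched
    # IN (...) queries in the diff above.
    chunk_map = {c.index_node_id: c for c in chunks}
    segment_map = {s.id: s for s in segments}

    results = []
    for node_id in doc_node_ids:
        chunk = chunk_map.get(node_id)  # was: one SELECT per document
        if chunk is None:
            continue
        segment = segment_map.get(chunk.segment_id)  # was: another SELECT
        if segment is not None:
            results.append((segment, chunk))
    return results


chunks = [FakeChunk("n1", "s1"), FakeChunk("n2", "s1")]
segments = [FakeSegment("s1", "parent text")]
print(format_results(["n1", "n2", "n3"], chunks, segments))
```

One detail worth double-checking in the real query: `load_only(DocumentSegment.id, DocumentSegment.content, DocumentSegment.answer)` defers every other column, yet the normal-document branch later reads `s.index_node_id` and `s.dataset_id` while scanning `segment_map.values()`. With SQLAlchemy's deferred loading, each such access can fire a hidden per-row SELECT, so adding those two columns to the `load_only` list (and keying a second map by `index_node_id` instead of scanning with `next(...)`) would keep the loop fully query-free.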
diff --git a/api/migrations/versions/2025_03_29_2227-6a9f914f656c_change_documentsegment_and_childchunk_.py b/api/migrations/versions/2025_03_29_2227-6a9f914f656c_change_documentsegment_and_childchunk_.py
new file mode 100644
index 0000000000..45904f0c80
--- /dev/null
+++ b/api/migrations/versions/2025_03_29_2227-6a9f914f656c_change_documentsegment_and_childchunk_.py
@@ -0,0 +1,43 @@
+"""change documentsegment and childchunk indexes
+
+Revision ID: 6a9f914f656c
+Revises: d20049ed0af6
+Create Date: 2025-03-29 22:27:24.789481
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '6a9f914f656c'
+down_revision = 'd20049ed0af6'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('child_chunks', schema=None) as batch_op:
+        batch_op.create_index('child_chunks_node_idx', ['index_node_id', 'dataset_id'], unique=False)
+        batch_op.create_index('child_chunks_segment_idx', ['segment_id'], unique=False)
+
+    with op.batch_alter_table('document_segments', schema=None) as batch_op:
+        batch_op.drop_index('document_segment_dataset_node_idx')
+        batch_op.create_index('document_segment_node_dataset_idx', ['index_node_id', 'dataset_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('document_segments', schema=None) as batch_op:
+        batch_op.drop_index('document_segment_node_dataset_idx')
+        batch_op.create_index('document_segment_dataset_node_idx', ['dataset_id', 'index_node_id'], unique=False)
+
+    with op.batch_alter_table('child_chunks', schema=None) as batch_op:
+        batch_op.drop_index('child_chunks_segment_idx')
+        batch_op.drop_index('child_chunks_node_idx')
+
+    # ### end Alembic commands ###
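Since the migration both adds and renames indexes, it is worth confirming the final state after running `flask db upgrade`. A quick sanity check with SQLAlchemy's inspector (the connection URL is a placeholder; point it at the actual database):

```python
from sqlalchemy import create_engine, inspect

# Placeholder URL -- substitute the real database credentials.
engine = create_engine("postgresql://user:pass@localhost:5432/dify")
inspector = inspect(engine)

for table in ("document_segments", "child_chunks"):
    indexes = {ix["name"]: ix["column_names"] for ix in inspector.get_indexes(table)}
    print(table, indexes)

# Expected after upgrade, assuming the migration ran cleanly:
#   document_segments: document_segment_node_dataset_idx on ['index_node_id', 'dataset_id']
#                      (document_segment_dataset_node_idx should be gone)
#   child_chunks:      child_chunks_node_idx on ['index_node_id', 'dataset_id']
#                      child_chunks_segment_idx on ['segment_id']
```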
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 47f96c669e..d6708ac88b 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -643,7 +643,7 @@ class DocumentSegment(db.Model):  # type: ignore[name-defined]
         db.Index("document_segment_document_id_idx", "document_id"),
         db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
         db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
-        db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"),
+        db.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"),
         db.Index("document_segment_tenant_idx", "tenant_id"),
     )
 
@@ -791,6 +791,8 @@ class ChildChunk(db.Model):  # type: ignore[name-defined]
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
         db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
+        db.Index("child_chunks_node_idx", "index_node_id", "dataset_id"),
+        db.Index("child_chunks_segment_idx", "segment_id"),
     )
 
     # initial fields
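The column order in the reordered index is the point of the change: a B-tree composite index can only seek on a query whose filter constrains the leading column, and the new batched lookups filter on `index_node_id IN (...)` without always pinning `dataset_id` first. A tiny self-contained demonstration of the planner behavior (SQLite stands in for PostgreSQL here, but the leading-column rule is the same in spirit):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE seg (dataset_id TEXT, index_node_id TEXT)")
con.execute("CREATE INDEX node_dataset_idx ON seg (index_node_id, dataset_id)")

# Filtering on the leading column lets the planner seek into the index.
plan = con.execute(
    "EXPLAIN QUERY PLAN SELECT * FROM seg WHERE index_node_id IN (?, ?)",
    ("a", "b"),
).fetchall()
print(plan)  # mentions SEARCH ... USING ... node_dataset_idx

# Filtering only on the trailing column cannot seek; SQLite may still walk
# the index as a covering index, but it degrades to a scan.
plan = con.execute(
    "EXPLAIN QUERY PLAN SELECT * FROM seg WHERE dataset_id = ?", ("d",)
).fetchall()
print(plan)
```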
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index f8c1c1d297..0b98065f5d 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -29,15 +29,6 @@ class HitTestingService:
         external_retrieval_model: dict,
         limit: int = 10,
     ) -> dict:
-        if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
-            return {
-                "query": {
-                    "content": query,
-                    "tsne_position": {"x": 0, "y": 0},
-                },
-                "records": [],
-            }
-
         start = time.perf_counter()
 
         # get retrieval model , if the model is not setting , using default
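With the early returns gone from both `RetrievalService.retrieve` and the hit-testing path, an empty dataset now flows through the normal retrieval pipeline and simply comes back with no records, so the guard no longer has to be duplicated at each call site. The new timings are emitted through the root logger (the diff calls `logging.debug` directly rather than a module-level logger), so they only appear once DEBUG is enabled there; a minimal sketch, not Dify's actual logging config:

```python
import logging

# Illustrative only -- the real handler/format choices live in the app config.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s")

# With DEBUG enabled, the patched code paths log lines such as:
#   ... DEBUG embedding_search took 0.12 seconds
#   ... DEBUG Formatting retrieval documents took 0.04 seconds
```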