diff --git a/api/models/dataset.py b/api/models/dataset.py
index 567f7db432..1cf3dc42fe 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -13,6 +13,7 @@ from typing import Any, cast
 
 from sqlalchemy import func
 from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.orm import Mapped
 
 from configs import dify_config
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -515,7 +516,7 @@ class DocumentSegment(db.Model):  # type: ignore[name-defined]
     tenant_id = db.Column(StringUUID, nullable=False)
     dataset_id = db.Column(StringUUID, nullable=False)
     document_id = db.Column(StringUUID, nullable=False)
-    position = db.Column(db.Integer, nullable=False)
+    position: Mapped[int]
     content = db.Column(db.Text, nullable=False)
     answer = db.Column(db.Text, nullable=True)
     word_count = db.Column(db.Integer, nullable=False)
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index dbef6b708e..e2d2392797 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -5,7 +5,8 @@ import uuid
 
 import click
 from celery import shared_task  # type: ignore
-from sqlalchemy import func
+from sqlalchemy import func, select
+from sqlalchemy.orm import Session
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -18,7 +19,12 @@ from services.vector_service import VectorService
 
 @shared_task(queue="dataset")
 def batch_create_segment_to_index_task(
-    job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str
+    job_id: str,
+    content: list,
+    dataset_id: str,
+    document_id: str,
+    tenant_id: str,
+    user_id: str,
 ):
     """
     Async batch create segment to index
@@ -37,71 +43,80 @@ def batch_create_segment_to_index_task(
     indexing_cache_key = "segment_batch_import_{}".format(job_id)
 
     try:
-        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
-        if not dataset:
-            raise ValueError("Dataset not exist.")
+        with Session(db.engine) as session:
+            dataset = session.get(Dataset, dataset_id)
+            if not dataset:
+                raise ValueError("Dataset not exist.")
 
-        dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
-        if not dataset_document:
-            raise ValueError("Document not exist.")
+            dataset_document = session.get(Document, document_id)
+            if not dataset_document:
+                raise ValueError("Document not exist.")
 
-        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
-            raise ValueError("Document is not available.")
-        document_segments = []
-        embedding_model = None
-        if dataset.indexing_technique == "high_quality":
-            model_manager = ModelManager()
-            embedding_model = model_manager.get_model_instance(
-                tenant_id=dataset.tenant_id,
-                provider=dataset.embedding_model_provider,
-                model_type=ModelType.TEXT_EMBEDDING,
-                model=dataset.embedding_model,
+            if (
+                not dataset_document.enabled
+                or dataset_document.archived
+                or dataset_document.indexing_status != "completed"
+            ):
+                raise ValueError("Document is not available.")
+            document_segments = []
+            embedding_model = None
+            if dataset.indexing_technique == "high_quality":
+                model_manager = ModelManager()
+                embedding_model = model_manager.get_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model,
+                )
+            word_count_change = 0
+            segments_to_insert: list[str] = []
+            max_position_stmt = select(func.max(DocumentSegment.position)).where(
+                DocumentSegment.document_id == dataset_document.id
             )
-        word_count_change = 0
-        segments_to_insert: list[str] = []  # Explicitly type hint the list as List[str]
-        for segment in content:
-            content_str = segment["content"]
-            doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content_str)
-            # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0
-            max_position = (
-                db.session.query(func.max(DocumentSegment.position))
-                .filter(DocumentSegment.document_id == dataset_document.id)
-                .scalar()
-            )
-            segment_document = DocumentSegment(
-                tenant_id=tenant_id,
-                dataset_id=dataset_id,
-                document_id=document_id,
-                index_node_id=doc_id,
-                index_node_hash=segment_hash,
-                position=max_position + 1 if max_position else 1,
-                content=content_str,
-                word_count=len(content_str),
-                tokens=tokens,
-                created_by=user_id,
-                indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
-                status="completed",
-                completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
-            )
-            if dataset_document.doc_form == "qa_model":
-                segment_document.answer = segment["answer"]
-                segment_document.word_count += len(segment["answer"])
-            word_count_change += segment_document.word_count
-            db.session.add(segment_document)
-            document_segments.append(segment_document)
-            segments_to_insert.append(str(segment))  # Cast to string if needed
-        # update document word count
-        dataset_document.word_count += word_count_change
-        db.session.add(dataset_document)
-        # add index to db
-        VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
-        db.session.commit()
+            max_position = session.scalar(max_position_stmt) or 1
+            for segment in content:
+                content_str = segment["content"]
+                doc_id = str(uuid.uuid4())
+                segment_hash = helper.generate_text_hash(content_str)
+                # calc embedding use tokens
+                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0
+                segment_document = DocumentSegment(
+                    tenant_id=tenant_id,
+                    dataset_id=dataset_id,
+                    document_id=document_id,
+                    index_node_id=doc_id,
+                    index_node_hash=segment_hash,
+                    position=max_position,
+                    content=content_str,
+                    word_count=len(content_str),
+                    tokens=tokens,
+                    created_by=user_id,
+                    indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                    status="completed",
+                    completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                )
+                max_position += 1
+                if dataset_document.doc_form == "qa_model":
+                    segment_document.answer = segment["answer"]
+                    segment_document.word_count += len(segment["answer"])
+                word_count_change += segment_document.word_count
+                session.add(segment_document)
+                document_segments.append(segment_document)
+                segments_to_insert.append(str(segment))  # Cast to string if needed
+            # update document word count
+            dataset_document.word_count += word_count_change
+            session.add(dataset_document)
+            # add index to db
+            VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
+            session.commit()
+
         redis_client.setex(indexing_cache_key, 600, "completed")
         end_at = time.perf_counter()
         logging.info(
-            click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green")
+            click.style(
+                "Segment batch created job: {} latency: {}".format(job_id, end_at - start_at),
+                fg="green",
+            )
         )
     except Exception as e:
         logging.exception("Segments batch created index failed")