From b00f94df644063732408561d405aed9364c8879b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Mon, 12 May 2025 13:52:33 +0800 Subject: [PATCH] fix: replace all dataset.Model.query to db.session.query(Model) (#19509) --- api/commands.py | 23 ++- api/controllers/console/datasets/datasets.py | 22 ++- .../console/datasets/datasets_document.py | 74 +++++--- .../console/datasets/datasets_segments.py | 87 ++++++---- .../service_api/dataset/document.py | 30 ++-- .../index_tool_callback_handler.py | 22 ++- api/core/indexing_runner.py | 28 +-- api/core/rag/retrieval/dataset_retrieval.py | 25 ++- .../dataset_multi_retriever_tool.py | 34 ++-- .../dataset_retriever_tool.py | 2 +- .../knowledge_retrieval_node.py | 16 +- api/models/dataset.py | 13 +- api/schedule/clean_unused_datasets_task.py | 21 ++- api/schedule/create_tidb_serverless_task.py | 4 +- .../mail_clean_document_notify_task.py | 6 +- .../update_tidb_serverless_status_task.py | 9 +- api/services/dataset_service.py | 162 +++++++++++------- api/services/external_knowledge_service.py | 75 ++++---- api/services/metadata_service.py | 36 ++-- api/tasks/create_segment_to_index_task.py | 4 +- api/tasks/deal_dataset_vector_index_task.py | 2 +- 21 files changed, 430 insertions(+), 265 deletions(-) diff --git a/api/commands.py b/api/commands.py index dc31dc0d80..c05ed786aa 100644 --- a/api/commands.py +++ b/api/commands.py @@ -6,6 +6,7 @@ from typing import Optional import click from flask import current_app +from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config @@ -297,11 +298,11 @@ def migrate_knowledge_vector_database(): page = 1 while True: try: - datasets = ( - Dataset.query.filter(Dataset.indexing_technique == "high_quality") - .order_by(Dataset.created_at.desc()) - .paginate(page=page, per_page=50) + stmt = ( + select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) ) + + datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) except NotFound: break @@ -592,11 +593,15 @@ def old_metadata_migration(): ) db.session.add(dataset_metadata_binding) else: - dataset_metadata_binding = DatasetMetadataBinding.query.filter( - DatasetMetadataBinding.dataset_id == document.dataset_id, - DatasetMetadataBinding.document_id == document.id, - DatasetMetadataBinding.metadata_id == dataset_metadata.id, - ).first() + dataset_metadata_binding = ( + db.session.query(DatasetMetadataBinding) # type: ignore + .filter( + DatasetMetadataBinding.dataset_id == document.dataset_id, + DatasetMetadataBinding.document_id == document.id, + DatasetMetadataBinding.metadata_id == dataset_metadata.id, + ) + .first() + ) if not dataset_metadata_binding: dataset_metadata_binding = DatasetMetadataBinding( tenant_id=document.tenant_id, diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 571a395780..981619b0cb 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -526,14 +526,20 @@ class DatasetIndexingStatusApi(Resource): ) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" - ).count() + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .count() + ) document.completed_segments = completed_segments document.total_segments = total_segments documents_status.append(marshal(document, document_status_fields)) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 68601adfed..ca18c25e74 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -6,7 +6,7 @@ from typing import cast from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal, marshal_with, reqparse -from sqlalchemy import asc, desc +from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -112,7 +112,7 @@ class GetProcessRuleApi(Resource): limits = DocumentService.DEFAULT_RULES["limits"] if document_id: # get the latest process rule - document = Document.query.get_or_404(document_id) + document = db.get_or_404(Document, document_id) dataset = DatasetService.get_dataset(document.dataset_id) @@ -175,7 +175,7 @@ class DatasetDocumentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) if search: search = f"%{search}%" @@ -209,18 +209,24 @@ class DatasetDocumentListApi(Resource): desc(Document.position), ) - paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: for document in documents: - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" - ).count() + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .count() + ) document.completed_segments = completed_segments document.total_segments = total_segments data = marshal(documents, document_with_segments_fields) @@ -563,14 +569,20 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents = self.get_batch_documents(dataset_id, batch) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" - ).count() + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .count() + ) document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: @@ -589,14 +601,20 @@ class DocumentIndexingStatusApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" - ).count() + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") + .count() + ) document.completed_segments = completed_segments document.total_segments = total_segments diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index a145038672..eee09ac32e 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -4,6 +4,7 @@ import pandas as pd from flask import request from flask_login import current_user from flask_restful import Resource, marshal, reqparse +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services @@ -26,6 +27,7 @@ from controllers.console.wraps import ( from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import login_required @@ -74,9 +76,14 @@ class DatasetDocumentSegmentListApi(Resource): hit_count_gte = args["hit_count_gte"] keyword = args["keyword"] - query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).order_by(DocumentSegment.position.asc()) + query = ( + select(DocumentSegment) + .filter( + DocumentSegment.document_id == str(document_id), + DocumentSegment.tenant_id == current_user.current_tenant_id, + ) + .order_by(DocumentSegment.position.asc()) + ) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -93,7 +100,7 @@ class DatasetDocumentSegmentListApi(Resource): elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) - segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) response = { "data": marshal(segments.items, segment_fields), @@ -276,9 +283,11 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor @@ -320,9 +329,11 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor @@ -423,9 +434,11 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") if not current_user.is_dataset_editor: @@ -478,9 +491,11 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") parser = reqparse.RequestParser() @@ -523,9 +538,11 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor @@ -567,16 +584,20 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ChildChunk.query.filter( - ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id - ).first() + child_chunk = ( + db.session.query(ChildChunk) + .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .first() + ) if not child_chunk: raise NotFound("Child chunk not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor @@ -612,16 +633,20 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ChildChunk.query.filter( - ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id - ).first() + child_chunk = ( + db.session.query(ChildChunk) + .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .first() + ) if not child_chunk: raise NotFound("Child chunk not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 33eda37014..f0f39fc2e5 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -2,10 +2,10 @@ import json from flask import request from flask_restful import marshal, reqparse -from sqlalchemy import desc +from sqlalchemy import desc, select from werkzeug.exceptions import NotFound -import services.dataset_service +import services from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api from controllers.service_api.app.error import ( @@ -337,7 +337,7 @@ class DocumentListApi(DatasetApiResource): if not dataset: raise NotFound("Dataset not found.") - query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) + query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) if search: search = f"%{search}%" @@ -345,7 +345,7 @@ class DocumentListApi(DatasetApiResource): query = query.order_by(desc(Document.created_at), desc(Document.position)) - paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items response = { @@ -374,14 +374,20 @@ class DocumentIndexingStatusApi(DatasetApiResource): raise NotFound("Documents not found.") documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" - ).count() + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .count() + ) document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 56859df7f4..7908bd0467 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -46,14 +46,22 @@ class DatasetIndexToolCallbackHandler: DatasetDocument.id == document.metadata["document_id"] ).first() if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk = ChildChunk.query.filter( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ).first() + child_chunk = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ) + .first() + ) if child_chunk: - segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == child_chunk.segment_id) + .update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False + ) ) else: query = db.session.query(DocumentSegment).filter( diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 81bf59b2b6..c389496801 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -51,7 +51,7 @@ class IndexingRunner: for dataset_document in dataset_documents: try: # get dataset - dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") @@ -103,15 +103,17 @@ class IndexingRunner: """Run the indexing process when the index_status is splitting.""" try: # get dataset - dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete - document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, document_id=dataset_document.id - ).all() + document_segments = ( + db.session.query(DocumentSegment) + .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .all() + ) for document_segment in document_segments: db.session.delete(document_segment) @@ -162,15 +164,17 @@ class IndexingRunner: """Run the indexing process when the index_status is indexing.""" try: # get dataset - dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete - document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, document_id=dataset_document.id - ).all() + document_segments = ( + db.session.query(DocumentSegment) + .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .all() + ) documents = [] if document_segments: @@ -254,7 +258,7 @@ class IndexingRunner: embedding_model_instance = None if dataset_id: - dataset = Dataset.query.filter_by(id=dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("Dataset not found.") if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": @@ -587,7 +591,7 @@ class IndexingRunner: @staticmethod def _process_keyword_index(flask_app, dataset_id, document_id, documents): with flask_app.app_context(): - dataset = Dataset.query.filter_by(id=dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("no dataset found") keyword = Keyword(dataset) @@ -676,7 +680,7 @@ class IndexingRunner: """ Update the document segment by document id. """ - DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) + db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() def _transform( diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 9216b31b8e..444d7ee329 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -237,7 +237,7 @@ class DatasetRetrieval: if show_retrieve_source: for record in records: segment = record.segment - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() document = DatasetDocument.query.filter( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, @@ -511,14 +511,23 @@ class DatasetRetrieval: ).first() if dataset_document: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk = ChildChunk.query.filter( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ).first() + child_chunk = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ) + .first() + ) if child_chunk: - segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == child_chunk.segment_id) + .update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) ) db.session.commit() else: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 032274b87e..04437ea6d8 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): document_context_list = [] index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(self.dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ) + .all() + ) if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} @@ -106,12 +110,16 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): context_list = [] resource_number = 1 for segment in sorted_segments: - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = Document.query.filter( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() + dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + document = ( + db.session.query(Document) + .filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ) + .first() + ) if dataset and document: source = { "position": resource_number, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index ed97b44f95..c19c357d2a 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -185,7 +185,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if self.return_resource: for record in records: segment = record.segment - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() document = DatasetDocument.query.filter( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index c84a1897de..962deba1fd 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -275,12 +275,16 @@ class KnowledgeRetrievalNode(LLMNode): if records: for record in records: segment = record.segment - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = Document.query.filter( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() + dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore + document = ( + db.session.query(Document) + .filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ) + .first() + ) if dataset and document: source = { "metadata": { diff --git a/api/models/dataset.py b/api/models/dataset.py index 94696f1633..ad43d6f371 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -93,7 +93,8 @@ class Dataset(Base): @property def latest_process_rule(self): return ( - DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == self.id) .order_by(DatasetProcessRule.created_at.desc()) .first() ) @@ -138,7 +139,8 @@ class Dataset(Base): @property def word_count(self): return ( - Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) + db.session.query(Document) + .with_entities(func.coalesce(func.sum(Document.word_count))) .filter(Document.dataset_id == self.id) .scalar() ) @@ -440,12 +442,13 @@ class Document(Base): @property def segment_count(self): - return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count() + return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count() @property def hit_count(self): return ( - DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) + db.session.query(DocumentSegment) + .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) .filter(DocumentSegment.document_id == self.id) .scalar() ) @@ -892,7 +895,7 @@ class DatasetKeywordTable(Base): return dct # get dataset - dataset = Dataset.query.filter_by(id=self.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() if not dataset: return None if self.data_source_type == "database": diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 4e7e443c2c..c0cd42a226 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -2,7 +2,7 @@ import datetime import time import click -from sqlalchemy import func +from sqlalchemy import func, select from werkzeug.exceptions import NotFound import app @@ -51,8 +51,9 @@ def clean_unused_datasets_task(): ) # Main query with join and filter - datasets = ( - Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + stmt = ( + select(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_sandbox_clean_day, @@ -60,9 +61,10 @@ def clean_unused_datasets_task(): func.coalesce(document_subquery_old.c.document_count, 0) > 0, ) .order_by(Dataset.created_at.desc()) - .paginate(page=1, per_page=50) ) + datasets = db.paginate(stmt, page=1, per_page=50) + except NotFound: break if datasets.items is None or len(datasets.items) == 0: @@ -99,7 +101,7 @@ def clean_unused_datasets_task(): # update document update_params = {Document.enabled: False} - Document.query.filter_by(dataset_id=dataset.id).update(update_params) + db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) except Exception as e: @@ -135,8 +137,9 @@ def clean_unused_datasets_task(): ) # Main query with join and filter - datasets = ( - Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + stmt = ( + select(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_pro_clean_day, @@ -144,8 +147,8 @@ def clean_unused_datasets_task(): func.coalesce(document_subquery_old.c.document_count, 0) > 0, ) .order_by(Dataset.created_at.desc()) - .paginate(page=1, per_page=50) ) + datasets = db.paginate(stmt, page=1, per_page=50) except NotFound: break @@ -175,7 +178,7 @@ def clean_unused_datasets_task(): # update document update_params = {Document.enabled: False} - Document.query.filter_by(dataset_id=dataset.id).update(update_params) + db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() click.echo( click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green") diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 1c985461c6..8a02278de8 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -19,7 +19,9 @@ def create_tidb_serverless_task(): while True: try: # check the number of idle tidb serverless - idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count() + idle_tidb_serverless_number = ( + db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count() + ) if idle_tidb_serverless_number >= tidb_serverless_number: break # create tidb serverless diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 29d86935b5..2ceba8c486 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -29,7 +29,9 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() + dataset_auto_disable_logs = ( + db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all() + ) # group by tenant_id dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) for dataset_auto_disable_log in dataset_auto_disable_logs: @@ -65,7 +67,7 @@ def mail_clean_document_notify_task(): ) for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = Dataset.query.filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if dataset: document_count = len(document_ids) knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 11a39e60ee..ce4ecb6e7c 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -5,6 +5,7 @@ import click import app from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from extensions.ext_database import db from models.dataset import TidbAuthBinding @@ -14,9 +15,11 @@ def update_tidb_serverless_status_task(): start_at = time.perf_counter() try: # check the number of idle tidb serverless - tidb_serverless_list = TidbAuthBinding.query.filter( - TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" - ).all() + tidb_serverless_list = ( + db.session.query(TidbAuthBinding) + .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + .all() + ) if len(tidb_serverless_list) == 0: return # update tidb serverless status diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index de90355ebf..f9fe39f977 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -9,7 +9,7 @@ from collections import Counter from typing import Any, Optional from flask_login import current_user -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): - query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) + query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) if user: # get permitted dataset ids - dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() + dataset_permission = ( + db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all() + ) permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None if user.current_role == TenantAccountRole.DATASET_OPERATOR: @@ -129,7 +131,7 @@ class DatasetService: else: return [], 0 - datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False) return datasets.items, datasets.total @@ -153,9 +155,10 @@ class DatasetService: @staticmethod def get_datasets_by_ids(ids, tenant_id): - datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( - page=1, per_page=len(ids), max_per_page=len(ids), error_out=False - ) + stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) + + datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) + return datasets.items, datasets.total @staticmethod @@ -174,7 +177,7 @@ class DatasetService: retrieval_model: Optional[RetrievalModel] = None, ): # check if dataset name already exists - if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): + if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None if indexing_technique == "high_quality": @@ -235,7 +238,7 @@ class DatasetService: @staticmethod def get_dataset(dataset_id) -> Optional[Dataset]: - dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() + dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() return dataset @staticmethod @@ -436,7 +439,7 @@ class DatasetService: # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] - dataset.query.filter_by(id=dataset_id).update(filtered_data) + db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data) db.session.commit() if action: @@ -460,7 +463,7 @@ class DatasetService: @staticmethod def dataset_use_check(dataset_id) -> bool: - count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count() + count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count() if count > 0: return True return False @@ -475,7 +478,9 @@ class DatasetService: logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") raise NoPermissionError("You do not have permission to access this dataset.") if dataset.permission == "partial_members": - user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() + user_permission = ( + db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first() + ) if ( not user_permission and dataset.tenant_id != user.current_tenant_id @@ -499,23 +504,24 @@ class DatasetService: elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( - dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() + dp.dataset_id == dataset.id + for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all() ): raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): - dataset_queries = ( - DatasetQuery.query.filter_by(dataset_id=dataset_id) - .order_by(db.desc(DatasetQuery.created_at)) - .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) - ) + stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at)) + + dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False) + return dataset_queries.items, dataset_queries.total @staticmethod def get_related_apps(dataset_id: str): return ( - AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) + db.session.query(AppDatasetJoin) + .filter(AppDatasetJoin.dataset_id == dataset_id) .order_by(db.desc(AppDatasetJoin.created_at)) .all() ) @@ -530,10 +536,14 @@ class DatasetService: } # get recent 30 days auto disable logs start_date = datetime.datetime.now() - datetime.timedelta(days=30) - dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( - DatasetAutoDisableLog.dataset_id == dataset_id, - DatasetAutoDisableLog.created_at >= start_date, - ).all() + dataset_auto_disable_logs = ( + db.session.query(DatasetAutoDisableLog) + .filter( + DatasetAutoDisableLog.dataset_id == dataset_id, + DatasetAutoDisableLog.created_at >= start_date, + ) + .all() + ) if dataset_auto_disable_logs: return { "document_ids": [log.document_id for log in dataset_auto_disable_logs], @@ -873,7 +883,9 @@ class DocumentService: @staticmethod def get_documents_position(dataset_id): - document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() + document = ( + db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() + ) if document: return document.position + 1 else: @@ -1010,13 +1022,17 @@ class DocumentService: } # check duplicate if knowledge_config.duplicate: - document = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ).first() + document = ( + db.session.query(Document) + .filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ) + .first() + ) if document: document.dataset_process_rule_id = dataset_process_rule.id # type: ignore document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) @@ -1054,12 +1070,16 @@ class DocumentService: raise ValueError("No notion info list found.") exist_page_ids = [] exist_document = {} - documents = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", - enabled=True, - ).all() + documents = ( + db.session.query(Document) + .filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ) + .all() + ) if documents: for document in documents: data_source_info = json.loads(document.data_source_info) @@ -1206,12 +1226,16 @@ class DocumentService: @staticmethod def get_tenant_documents_count(): - documents_count = Document.query.filter( - Document.completed_at.isnot(None), - Document.enabled == True, - Document.archived == False, - Document.tenant_id == current_user.current_tenant_id, - ).count() + documents_count = ( + db.session.query(Document) + .filter( + Document.completed_at.isnot(None), + Document.enabled == True, + Document.archived == False, + Document.tenant_id == current_user.current_tenant_id, + ) + .count() + ) return documents_count @staticmethod @@ -1328,7 +1352,7 @@ class DocumentService: db.session.commit() # update document segment update_params = {DocumentSegment.status: "re_segment"} - DocumentSegment.query.filter_by(document_id=document.id).update(update_params) + db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params) db.session.commit() # trigger async task document_indexing_update_task.delay(document.dataset_id, document.id) @@ -1918,7 +1942,8 @@ class SegmentService: @classmethod def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): index_node_ids = ( - DocumentSegment.query.with_entities(DocumentSegment.index_node_id) + db.session.query(DocumentSegment) + .with_entities(DocumentSegment.index_node_id) .filter( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, @@ -2157,20 +2182,28 @@ class SegmentService: def get_child_chunks( cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None ): - query = ChildChunk.query.filter_by( - tenant_id=current_user.current_tenant_id, - dataset_id=dataset_id, - document_id=document_id, - segment_id=segment_id, - ).order_by(ChildChunk.position.asc()) + query = ( + select(ChildChunk) + .filter_by( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset_id, + document_id=document_id, + segment_id=segment_id, + ) + .order_by(ChildChunk.position.asc()) + ) if keyword: query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) - return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]: """Get a child chunk by its ID.""" - result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first() + result = ( + db.session.query(ChildChunk) + .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) + .first() + ) return result if isinstance(result, ChildChunk) else None @classmethod @@ -2184,7 +2217,7 @@ class SegmentService: limit: int = 20, ): """Get segments for a document with optional filtering.""" - query = DocumentSegment.query.filter( + query = select(DocumentSegment).filter( DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id ) @@ -2194,9 +2227,8 @@ class SegmentService: if keyword: query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) - paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate( - page=page, per_page=limit, max_per_page=100, error_out=False - ) + query = query.order_by(DocumentSegment.position.asc()) + paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) return paginated_segments.items, paginated_segments.total @@ -2236,9 +2268,11 @@ class SegmentService: raise ValueError(ex.description) # check segment - segment = DocumentSegment.query.filter( - DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) + .first() + ) if not segment: raise NotFound("Segment not found.") @@ -2251,9 +2285,11 @@ class SegmentService: @classmethod def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: """Get a segment by its ID.""" - result = DocumentSegment.query.filter( - DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id - ).first() + result = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) + .first() + ) return result if isinstance(result, DocumentSegment) else None diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 6b75c29d95..eb50d79494 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast from urllib.parse import urlparse import httpx +from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy @@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError class ExternalDatasetService: @staticmethod - def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]: - query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by( - ExternalKnowledgeApis.created_at.desc() + def get_external_knowledge_apis( + page, per_page, tenant_id, search=None + ) -> tuple[list[ExternalKnowledgeApis], int | None]: + query = ( + select(ExternalKnowledgeApis) + .filter(ExternalKnowledgeApis.tenant_id == tenant_id) + .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) - external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + external_knowledge_apis = db.paginate( + select=query, page=page, per_page=per_page, max_per_page=100, error_out=False + ) return external_knowledge_apis.items, external_knowledge_apis.total @@ -92,18 +99,18 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_api_id - ).first() + external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first() + ) if external_knowledge_api is None: raise ValueError("api template not found") return external_knowledge_api @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_api_id, tenant_id=tenant_id - ).first() + external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + ) if external_knowledge_api is None: raise ValueError("api template not found") if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: @@ -120,9 +127,9 @@ class ExternalDatasetService: @staticmethod def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str): - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_api_id, tenant_id=tenant_id - ).first() + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + ) if external_knowledge_api is None: raise ValueError("api template not found") @@ -131,25 +138,29 @@ class ExternalDatasetService: @staticmethod def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: - count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count() + count = ( + db.session.query(ExternalKnowledgeBindings) + .filter_by(external_knowledge_api_id=external_knowledge_api_id) + .count() + ) if count > 0: return True, count return False, 0 @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( - dataset_id=dataset_id, tenant_id=tenant_id - ).first() + external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ( + db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() + ) if not external_knowledge_binding: raise ValueError("external knowledge binding not found") return external_knowledge_binding @staticmethod def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict): - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_api_id, tenant_id=tenant_id - ).first() + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + ) if external_knowledge_api is None: raise ValueError("api template not found") settings = json.loads(external_knowledge_api.settings) @@ -212,11 +223,13 @@ class ExternalDatasetService: @staticmethod def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset: # check if dataset name already exists - if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first(): + if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.") - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( - id=args.get("external_knowledge_api_id"), tenant_id=tenant_id - ).first() + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis) + .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id) + .first() + ) if external_knowledge_api is None: raise ValueError("api template not found") @@ -254,15 +267,17 @@ class ExternalDatasetService: external_retrieval_parameters: dict, metadata_condition: Optional[MetadataCondition] = None, ) -> list: - external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( - dataset_id=dataset_id, tenant_id=tenant_id - ).first() + external_knowledge_binding = ( + db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() + ) if not external_knowledge_binding: raise ValueError("external knowledge binding not found") - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_binding.external_knowledge_api_id - ).first() + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis) + .filter_by(id=external_knowledge_binding.external_knowledge_api_id) + .first() + ) if not external_knowledge_api: raise ValueError("external api template not found") diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index c47c16f2f7..26d6d4ce18 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -20,9 +20,11 @@ class MetadataService: @staticmethod def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: # check if metadata name already exists - if DatasetMetadata.query.filter_by( - tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name - ).first(): + if ( + db.session.query(DatasetMetadata) + .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name) + .first() + ): raise ValueError("Metadata name already exists.") for field in BuiltInField: if field.value == metadata_args.name: @@ -42,16 +44,18 @@ class MetadataService: def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists - if DatasetMetadata.query.filter_by( - tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name - ).first(): + if ( + db.session.query(DatasetMetadata) + .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name) + .first() + ): raise ValueError("Metadata name already exists.") for field in BuiltInField: if field.value == name: raise ValueError("Metadata name already exists in Built-in fields.") try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() + metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first() if metadata is None: raise ValueError("Metadata not found.") old_name = metadata.name @@ -60,7 +64,9 @@ class MetadataService: metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # update related documents - dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() + dataset_metadata_bindings = ( + db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() + ) if dataset_metadata_bindings: document_ids = [binding.document_id for binding in dataset_metadata_bindings] documents = DocumentService.get_document_by_ids(document_ids) @@ -82,13 +88,15 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset_id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() + metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first() if metadata is None: raise ValueError("Metadata not found.") db.session.delete(metadata) # deal related documents - dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() + dataset_metadata_bindings = ( + db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() + ) if dataset_metadata_bindings: document_ids = [binding.document_id for binding in dataset_metadata_bindings] documents = DocumentService.get_document_by_ids(document_ids) @@ -193,7 +201,7 @@ class MetadataService: db.session.add(document) db.session.commit() # deal metadata binding - DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete() + db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete() for metadata_value in operation.metadata_list: dataset_metadata_binding = DatasetMetadataBinding( tenant_id=current_user.current_tenant_id, @@ -230,9 +238,9 @@ class MetadataService: "id": item.get("id"), "name": item.get("name"), "type": item.get("type"), - "count": DatasetMetadataBinding.query.filter_by( - metadata_id=item.get("id"), dataset_id=dataset.id - ).count(), + "count": db.session.query(DatasetMetadataBinding) + .filter_by(metadata_id=item.get("id"), dataset_id=dataset.id) + .count(), } for item in dataset.doc_metadata or [] if item.get("id") != "built-in" diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 4500b2a44b..a3f811faa1 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] DocumentSegment.status: "indexing", DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } - DocumentSegment.query.filter_by(id=segment.id).update(update_params) + db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) db.session.commit() document = Document( page_content=segment.content, @@ -78,7 +78,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] DocumentSegment.status: "completed", DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } - DocumentSegment.query.filter_by(id=segment.id).update(update_params) + db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) db.session.commit() end_at = time.perf_counter() diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 075453e283..a27207f2f1 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -24,7 +24,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): start_at = time.perf_counter() try: - dataset = Dataset.query.filter_by(id=dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise Exception("Dataset not found")