diff --git a/api/commands.py b/api/commands.py
index c05ed786aa..66278a53a3 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -552,11 +552,12 @@ def old_metadata_migration():
     page = 1
     while True:
         try:
-            documents = (
-                DatasetDocument.query.filter(DatasetDocument.doc_metadata is not None)
+            stmt = (
+                select(DatasetDocument)
+                .filter(DatasetDocument.doc_metadata.is_not(None))
                 .order_by(DatasetDocument.created_at.desc())
-                .paginate(page=page, per_page=50)
             )
+            documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
         except NotFound:
             break
         if not documents:
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 9336c35a0d..4062972d08 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -66,7 +66,7 @@ class InstalledAppsListApi(Resource):
         parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
         args = parser.parse_args()

-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
+        recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first()
         if recommended_app is None:
             raise NotFound("App not found")

@@ -79,9 +79,11 @@ class InstalledAppsListApi(Resource):
         if not app.is_public:
             raise Forbidden("You can't install a non-public app")

-        installed_app = InstalledApp.query.filter(
-            and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
-        ).first()
+        installed_app = (
+            db.session.query(InstalledApp)
+            .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
+            .first()
+        )

         if installed_app is None:
             # todo: position
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index 7908bd0467..13c22213c4 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -1,3 +1,5 @@
+import logging
+
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -7,6 +9,8 @@ from extensions.ext_database import db
 from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument

+_logger = logging.getLogger(__name__)
+

 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
@@ -42,9 +46,14 @@ class DatasetIndexToolCallbackHandler:
         """Handle tool end."""
         for document in documents:
             if document.metadata is not None:
-                dataset_document = DatasetDocument.query.filter(
-                    DatasetDocument.id == document.metadata["document_id"]
-                ).first()
+                document_id = document.metadata["document_id"]
+                dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
+                if not dataset_document:
+                    _logger.warning(
+                        "Expected DatasetDocument record to exist, but none was found, document_id=%s",
+                        document_id,
+                    )
+                    continue
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     child_chunk = (
                         db.session.query(ChildChunk)
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index c389496801..848d897779 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -660,10 +660,10 @@ class IndexingRunner:
         """
         Update the document indexing status.
         """
-        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
+        count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count()
         if count > 0:
             raise DocumentIsPausedError()
-        document = DatasetDocument.query.filter_by(id=document_id).first()
+        document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
         if not document:
             raise DocumentIsDeletedPausedError()

@@ -672,7 +672,7 @@ class IndexingRunner:
         if extra_update_params:
             update_params.update(extra_update_params)

-        DatasetDocument.query.filter_by(id=document_id).update(update_params)
+        db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params)
         db.session.commit()

     @staticmethod
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index 00a2150875..4e14800d0a 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -317,7 +317,7 @@ class NotionExtractor(BaseExtractor):
         data_source_info["last_edited_time"] = last_edited_time
         update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}

-        DocumentModel.query.filter_by(id=document_model.id).update(update_params)
+        db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params)
         db.session.commit()

     def get_notion_last_edited_time(self) -> str:
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 444d7ee329..d3605da146 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -238,11 +238,15 @@ class DatasetRetrieval:
             for record in records:
                 segment = record.segment
                 dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
-                document = DatasetDocument.query.filter(
-                    DatasetDocument.id == segment.document_id,
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                ).first()
+                document = (
+                    db.session.query(DatasetDocument)
+                    .filter(
+                        DatasetDocument.id == segment.document_id,
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "dataset_id": dataset.id,
@@ -506,9 +510,11 @@ class DatasetRetrieval:
         dify_documents = [document for document in documents if document.provider == "dify"]
         for document in dify_documents:
             if document.metadata is not None:
-                dataset_document = DatasetDocument.query.filter(
-                    DatasetDocument.id == document.metadata["document_id"]
-                ).first()
+                dataset_document = (
+                    db.session.query(DatasetDocument)
+                    .filter(DatasetDocument.id == document.metadata["document_id"])
+                    .first()
+                )
                 if dataset_document:
                     if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                         child_chunk = (
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
index c19c357d2a..fff261e0bd 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
@@ -186,11 +186,15 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             for record in records:
                 segment = record.segment
                 dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
-                document = DatasetDocument.query.filter(
-                    DatasetDocument.id == segment.document_id,
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                ).first()
+                document = (
+                    db.session.query(DatasetDocument)  # type: ignore
+                    .filter(
+                        DatasetDocument.id == segment.document_id,
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "dataset_id": dataset.id,
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py
index 5e4d3ec323..f41f5264c7 100644
--- a/api/schedule/clean_messages.py
+++ b/api/schedule/clean_messages.py
@@ -1,4 +1,5 @@
 import datetime
+import logging
 import time

 import click
@@ -20,6 +21,8 @@ from models.model import (
 )
 from models.web import SavedMessage
 from services.feature_service import FeatureService

+_logger = logging.getLogger(__name__)
+

 @app.celery.task(queue="dataset")
 def clean_messages():
@@ -46,7 +49,14 @@ def clean_messages():
             break
         for message in messages:
             plan_sandbox_clean_message_day = message.created_at
-            app = App.query.filter_by(id=message.app_id).first()
+            app = db.session.query(App).filter_by(id=message.app_id).first()
+            if not app:
+                _logger.warning(
+                    "Expected App record to exist, but none was found, app_id=%s, message_id=%s",
+                    message.app_id,
+                    message.id,
+                )
+                continue
             features_cache_key = f"features:{app.tenant_id}"
             plan_cache = redis_client.get(features_cache_key)
             if plan_cache is None:
diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py
index 0b3ff5d47d..5ee813e1de 100644
--- a/api/schedule/mail_clean_document_notify_task.py
+++ b/api/schedule/mail_clean_document_notify_task.py
@@ -54,7 +54,7 @@ def mail_clean_document_notify_task():
             )
             if not current_owner_join:
                 continue
-            account = Account.query.filter(Account.id == current_owner_join.account_id).first()
+            account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first()
             if not account:
                 continue

diff --git a/api/services/vector_service.py b/api/services/vector_service.py
index 92422bf29d..18d10cc528 100644
--- a/api/services/vector_service.py
+++ b/api/services/vector_service.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional

 from core.model_manager import ModelInstance, ModelManager
@@ -12,6 +13,8 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
 from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import ParentMode

+_logger = logging.getLogger(__name__)
+

 class VectorService:
     @classmethod
@@ -22,7 +25,14 @@ class VectorService:

         for segment in segments:
             if doc_form == IndexType.PARENT_CHILD_INDEX:
-                document = DatasetDocument.query.filter_by(id=segment.document_id).first()
+                document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
+                if not document:
+                    _logger.warning(
+                        "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
+                        segment.document_id,
+                        segment.id,
+                    )
+                    continue
                 # get the process rule
                 processing_rule = (
                     db.session.query(DatasetProcessRule)
@@ -52,7 +62,7 @@ class VectorService:
                     raise ValueError("The knowledge base index technique is not high quality!")
                 cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
             else:
-                document = Document(
+                document = Document(  # type: ignore
                     page_content=segment.content,
                     metadata={
                         "doc_id": segment.index_node_id,
@@ -64,7 +74,7 @@ class VectorService:
                 documents.append(document)
         if len(documents) > 0:
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
+            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)  # type: ignore

     @classmethod
     def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
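
Every hunk above applies the same two changes: the legacy Flask-SQLAlchemy Model.query accessor is replaced by an explicit session query (db.session.query(Model), or a 2.0-style select() handed to db.paginate), and lookups that can return no row are guarded before being dereferenced. The following minimal sketch shows both patterns together, reusing the db and DatasetDocument imports that appear in the files above; the helper function names are illustrative only and are not part of the patch.

import logging

from sqlalchemy import select

from extensions.ext_database import db
from models.dataset import Document as DatasetDocument

_logger = logging.getLogger(__name__)


def load_dataset_document(document_id: str):
    # Session-based lookup; the legacy form removed by this diff was
    # DatasetDocument.query.filter_by(id=document_id).first().
    document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
    if not document:
        # Guard before dereferencing, mirroring the warnings added in the hunks above.
        _logger.warning("Expected DatasetDocument record to exist, but none was found, document_id=%s", document_id)
    return document


def paginate_documents_with_metadata(page: int = 1):
    # 2.0-style select() paginated through db.paginate(), as in api/commands.py above.
    stmt = (
        select(DatasetDocument)
        .filter(DatasetDocument.doc_metadata.is_not(None))
        .order_by(DatasetDocument.created_at.desc())
    )
    return db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)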