diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 4d8f826427..b42a441a3f 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -329,6 +329,7 @@ class DatasetRetrieval: """ if not query: return + dataset_queries = [] for dataset_id in dataset_ids: dataset_query = DatasetQuery( dataset_id=dataset_id, @@ -338,7 +339,9 @@ class DatasetRetrieval: created_by_role=user_from, created_by=user_id ) - db.session.add(dataset_query) + dataset_queries.append(dataset_query) + if dataset_queries: + db.session.add_all(dataset_queries) db.session.commit() def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index e9dfb75f17..1a0f3b0495 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,5 +1,7 @@ from typing import Any, cast +from sqlalchemy import func + from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy @@ -73,30 +75,33 @@ class KnowledgeRetrievalNode(BaseNode): def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[ dict[str, Any]]: - """ - A dataset tool is a tool that can be used to retrieve information from a dataset - :param node_data: node data - :param query: query - """ - tools = [] available_datasets = [] dataset_ids = node_data.dataset_ids - for dataset_id in dataset_ids: - # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == dataset_id - ).first() + # Subquery: Count the number of available documents for each dataset + subquery = db.session.query( + Document.dataset_id, + func.count(Document.id).label('available_document_count') + ).filter( + Document.indexing_status == 'completed', + Document.enabled == True, + Document.archived == False, + Document.dataset_id.in_(dataset_ids) + ).group_by(Document.dataset_id).having( + func.count(Document.id) > 0 + ).subquery() + + results = db.session.query(Dataset).join( + subquery, Dataset.id == subquery.c.dataset_id + ).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id.in_(dataset_ids) + ).all() + + for dataset in results: # pass if dataset is not available if not dataset: continue - - # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): - continue - available_datasets.append(dataset) all_documents = [] dataset_retrieval = DatasetRetrieval()