Optimize knowledge retrieval performance by batching dataset queries. (#4917)

JasonVV 2024-06-05 13:30:32 +08:00 committed by GitHub
parent 3006124e6d
commit 7749b71fff
2 changed files with 27 additions and 19 deletions


@@ -329,6 +329,7 @@ class DatasetRetrieval:
         """
         if not query:
             return
+        dataset_queries = []
         for dataset_id in dataset_ids:
             dataset_query = DatasetQuery(
                 dataset_id=dataset_id,
@@ -338,7 +339,9 @@ class DatasetRetrieval:
                 created_by_role=user_from,
                 created_by=user_id
             )
-            db.session.add(dataset_query)
+            dataset_queries.append(dataset_query)
+        if dataset_queries:
+            db.session.add_all(dataset_queries)
         db.session.commit()

     def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
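The hunk above collects the per-dataset DatasetQuery rows in a list and hands them to the session with a single add_all() call, instead of calling add() on every loop iteration, with one commit at the end. A minimal, self-contained sketch of that pattern follows; the QueryLog model, SQLite engine, and sample data are illustrative assumptions, not code from this repository:

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class QueryLog(Base):
    """Hypothetical stand-in for the DatasetQuery model."""
    __tablename__ = "query_logs"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String(36))
    content: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

dataset_ids = ["ds-1", "ds-2", "ds-3"]
query = "what is a vector index?"

with Session(engine) as session:
    # Build every row first, then register them with the session in a
    # single call rather than one session.add() per iteration.
    logs = [QueryLog(dataset_id=d, content=query) for d in dataset_ids]
    if logs:
        session.add_all(logs)
    session.commit()

Session.add_all() is part of SQLAlchemy's public Session API; the change keeps the per-query bookkeeping to one session call and one commit when many datasets are queried at once.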


@@ -1,5 +1,7 @@
 from typing import Any, cast

+from sqlalchemy import func
+
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
@@ -73,30 +75,33 @@ class KnowledgeRetrievalNode(BaseNode):
     def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[
         dict[str, Any]]:
-        """
-        A dataset tool is a tool that can be used to retrieve information from a dataset
-        :param node_data: node data
-        :param query: query
-        """
-        tools = []
         available_datasets = []
         dataset_ids = node_data.dataset_ids
-        for dataset_id in dataset_ids:
-            # get dataset from dataset id
-            dataset = db.session.query(Dataset).filter(
-                Dataset.tenant_id == self.tenant_id,
-                Dataset.id == dataset_id
-            ).first()
+
+        # Subquery: Count the number of available documents for each dataset
+        subquery = db.session.query(
+            Document.dataset_id,
+            func.count(Document.id).label('available_document_count')
+        ).filter(
+            Document.indexing_status == 'completed',
+            Document.enabled == True,
+            Document.archived == False,
+            Document.dataset_id.in_(dataset_ids)
+        ).group_by(Document.dataset_id).having(
+            func.count(Document.id) > 0
+        ).subquery()
+        results = db.session.query(Dataset).join(
+            subquery, Dataset.id == subquery.c.dataset_id
+        ).filter(
+            Dataset.tenant_id == self.tenant_id,
+            Dataset.id.in_(dataset_ids)
+        ).all()

+        for dataset in results:
             # pass if dataset is not available
             if not dataset:
                 continue
-
-            # pass if dataset is not available
-            if (dataset and dataset.available_document_count == 0
-                    and dataset.available_document_count == 0):
-                continue

             available_datasets.append(dataset)

         all_documents = []
         dataset_retrieval = DatasetRetrieval()
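The second hunk replaces the per-dataset SELECT loop with one grouped count over Document joined back to Dataset, so availability is resolved in a single round trip. Below is a self-contained sketch of the same query shape; the simplified Dataset/Document models, column types, and SQLite engine are assumptions for illustration, not the project's real schema:

from sqlalchemy import Boolean, ForeignKey, String, create_engine, func
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):
    """Simplified stand-in for the project's Dataset model."""
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))


class Document(Base):
    """Simplified stand-in for the project's Document model."""
    __tablename__ = "documents"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str] = mapped_column(ForeignKey("datasets.id"))
    indexing_status: Mapped[str] = mapped_column(String(32))
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    archived: Mapped[bool] = mapped_column(Boolean, default=False)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

tenant_id = "tenant-1"
dataset_ids = ["ds-1", "ds-2"]

with Session(engine) as session:
    session.add_all([
        Dataset(id="ds-1", tenant_id=tenant_id),
        Dataset(id="ds-2", tenant_id=tenant_id),
        Document(dataset_id="ds-1", indexing_status="completed"),
        Document(dataset_id="ds-2", indexing_status="indexing"),
    ])
    session.commit()

    # One grouped count of usable documents per dataset ...
    subquery = (
        session.query(Document.dataset_id,
                      func.count(Document.id).label("available_document_count"))
        .filter(Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
                Document.dataset_id.in_(dataset_ids))
        .group_by(Document.dataset_id)
        .having(func.count(Document.id) > 0)
        .subquery()
    )
    # ... joined back to Dataset, so only datasets with at least one
    # available document come back, in a single round trip.
    results = (
        session.query(Dataset)
        .join(subquery, Dataset.id == subquery.c.dataset_id)
        .filter(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
        .all()
    )
    print([d.id for d in results])  # -> ['ds-1']

The having(func.count(Document.id) > 0) clause already excludes datasets without available documents, which is why the old available_document_count check disappears from the diff.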