Feat/add knowledge include all filter (#12537)

This commit is contained in:
Jyong 2025-01-09 20:21:25 +08:00 committed by GitHub
parent 2e97ba5700
commit 14ee51aead
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 8 additions and 5 deletions

View File

@ -52,12 +52,12 @@ class DatasetListApi(Resource):
# provider = request.args.get("provider", default="vendor") # provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids") tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
if ids: if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else: else:
datasets, total = DatasetService.get_datasets( datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
) )
# check embedding setting # check embedding setting

View File

@ -31,8 +31,11 @@ class DatasetListApi(DatasetApiResource):
# provider = request.args.get("provider", default="vendor") # provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids") tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids) datasets, total = DatasetService.get_datasets(
page, limit, tenant_id, current_user, search, tag_ids, include_all
)
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

View File

@ -71,7 +71,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService: class DatasetService:
@staticmethod @staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None): def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user: if user:
@ -86,7 +86,7 @@ class DatasetService:
else: else:
return [], 0 return [], 0
else: else:
if user.current_role != TenantAccountRole.OWNER: if user.current_role != TenantAccountRole.OWNER or not include_all:
# show all datasets that the user has permission to access # show all datasets that the user has permission to access
if permitted_dataset_ids: if permitted_dataset_ids:
query = query.filter( query = query.filter(