From 1eb4caf02a9fdb1f3c5a7e8f304163d8823bfa5c Mon Sep 17 00:00:00 2001
From: cecilia-uu <117628326+cecilia-uu@users.noreply.github.com>
Date: Mon, 17 Jun 2024 12:19:05 +0800
Subject: [PATCH] create list_dataset api and tests (#1138)

### What problem does this PR solve?

This PR completes both the HTTP API and the Python SDK for 'list_dataset'. In addition, it adds tests for them.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
---
 api/apps/dataset_api.py                  | 26 +++++--
 api/db/services/knowledgebase_service.py | 23 ++++++
 sdk/python/ragflow/__init__.py           |  2 +
 sdk/python/ragflow/dataset.py            |  2 +-
 sdk/python/ragflow/ragflow.py            | 46 +++++++++---
 sdk/python/test/common.py                |  2 +-
 sdk/python/test/test_dataset.py          | 94 ++++++++++++++++++++++--
 7 files changed, 170 insertions(+), 25 deletions(-)

diff --git a/api/apps/dataset_api.py b/api/apps/dataset_api.py
index 8d3db6b21..9eec34148 100644
--- a/api/apps/dataset_api.py
+++ b/api/apps/dataset_api.py
@@ -46,7 +46,7 @@ from api.contants import NAME_LENGTH_LIMIT
 
 # ------------------------------ create a dataset ---------------------------------------
 @manager.route('/', methods=['POST'])
-@login_required # use login
+@login_required  # use login
 @validate_request("name")  # check name key
 def create_dataset():
     # Check if Authorization header is present
@@ -111,10 +111,27 @@ def create_dataset():
         if not KnowledgebaseService.save(**request_body):
             # failed to create new dataset
             return construct_result()
-        return construct_json_result(data={"dataset_id": request_body["id"]})
+        return construct_json_result(data={"dataset_name": request_body["name"]})
     except Exception as e:
         return construct_error_response(e)
 
 
+# -----------------------------list datasets-------------------------------------------------------
+@manager.route('/', methods=['GET'])
+@login_required
+def list_datasets():
+    offset = request.args.get("offset", 0)
+    count = request.args.get("count", -1)
+    orderby = request.args.get("orderby", "create_time")
+    desc = request.args.get("desc", True)
+    try:
+        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
+        kbs = KnowledgebaseService.get_by_tenant_ids(
+            [m["tenant_id"] for m in tenants], current_user.id, int(offset), int(count), orderby, desc)
+        return construct_json_result(data=kbs, code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
+    except Exception as e:
+        return construct_error_response(e)
+
+# ---------------------------------delete a dataset ----------------------------
 @manager.route('/', methods=['DELETE'])
 @login_required
@@ -135,8 +152,5 @@ def get_dataset(dataset_id):
     return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")
 
 
 
-@manager.route('/', methods=['GET'])
-@login_required
-def list_datasets():
-    return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
+
diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py
index f1bc131ea..61075f3e0 100644
--- a/api/db/services/knowledgebase_service.py
+++ b/api/db/services/knowledgebase_service.py
@@ -40,6 +40,29 @@ class KnowledgebaseService(CommonService):
 
         return list(kbs.dicts())
 
+    @classmethod
+    @DB.connection_context()
+    def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
+                          offset, count, orderby, desc):
+        kbs = cls.model.select().where(
+            ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
+                                                            TenantPermission.TEAM.value)) | (
+                     cls.model.tenant_id == user_id))
+            & (cls.model.status == StatusEnum.VALID.value)
+        )
+        if desc:
+            kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
+        else:
+            kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
+
+        kbs = list(kbs.dicts())
+
+        kbs_length = len(kbs)
+        if offset < 0 or offset > kbs_length:
+            raise IndexError("Offset is out of the valid range.")
+
+        return kbs[offset:offset+count]
+
     @classmethod
     @DB.connection_context()
     def get_detail(cls, kb_id):
diff --git a/sdk/python/ragflow/__init__.py b/sdk/python/ragflow/__init__.py
index 889b8322d..1ef2f0879 100644
--- a/sdk/python/ragflow/__init__.py
+++ b/sdk/python/ragflow/__init__.py
@@ -1,3 +1,5 @@
 import importlib.metadata
 
 __version__ = importlib.metadata.version("ragflow")
+
+from .ragflow import RAGFlow
diff --git a/sdk/python/ragflow/dataset.py b/sdk/python/ragflow/dataset.py
index 1d4b56cea..5984aa62f 100644
--- a/sdk/python/ragflow/dataset.py
+++ b/sdk/python/ragflow/dataset.py
@@ -18,4 +18,4 @@ class DataSet:
         self.user_key = user_key
         self.dataset_url = dataset_url
         self.uuid = uuid
-        self.name = name
\ No newline at end of file
+        self.name = name
diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py
index ec106687f..2deb69ab2 100644
--- a/sdk/python/ragflow/ragflow.py
+++ b/sdk/python/ragflow/ragflow.py
@@ -17,7 +17,10 @@ import os
 import requests
 import json
 
-class RAGFLow:
+from httpx import HTTPError
+
+
+class RAGFlow:
     def __init__(self, user_key, base_url, version = 'v1'):
         '''
         api_url: http://<host_address>/api/v1
@@ -36,16 +39,39 @@ class RAGFLow:
         result_dict = json.loads(res.text)
         return result_dict
 
-    def delete_dataset(self, dataset_name = None, dataset_id = None):
+    def delete_dataset(self, dataset_name=None, dataset_id=None):
         return dataset_name
 
-    def list_dataset(self):
-        response = requests.get(self.dataset_url)
-        print(response)
-        if response.status_code == 200:
-            return response.json()['datasets']
-        else:
-            return None
+    def list_dataset(self, offset=0, count=-1, orderby="create_time", desc=True):
+        params = {
+            "offset": offset,
+            "count": count,
+            "orderby": orderby,
+            "desc": desc
+        }
+        try:
+            response = requests.get(url=self.dataset_url, params=params, headers=self.authorization_header)
+            response.raise_for_status()  # if it is not 200
+            original_data = response.json()
+            # TODO: format the data
+            # print(original_data)
+            # # Process the original data into the desired format
+            # formatted_data = {
+            #     "datasets": [
+            #         {
+            #             "id": dataset["id"],
+            #             "created": dataset["create_time"],  # Adjust the key based on the actual response
+            #             "fileCount": dataset["doc_num"],  # Adjust the key based on the actual response
+            #             "name": dataset["name"]
+            #         }
+            #         for dataset in original_data
+            #     ]
+            # }
+            return response.status_code, original_data
+        except HTTPError as http_err:
+            print(f"HTTP error occurred: {http_err}")
+        except Exception as err:
+            print(f"An error occurred: {err}")
 
     def get_dataset(self, dataset_id):
         endpoint = f"{self.dataset_url}/{dataset_id}"
@@ -61,4 +87,4 @@
         if response.status_code == 200:
             return True
         else:
-            return False
\ No newline at end of file
+            return False
diff --git a/sdk/python/test/common.py b/sdk/python/test/common.py
index 7187f98ed..c7525297d 100644
--- a/sdk/python/test/common.py
+++ b/sdk/python/test/common.py
@@ -1,4 +1,4 @@
-API_KEY = 'IjJiMTVkZWNhMjU3MzExZWY4YzNiNjQ0OTdkMTllYjM3Ig.ZmQZrA.x9Z7c-1ErBUSL3m8SRtBRgGq5uE'
+API_KEY = 'ImFmNWQ3YTY0Mjg5NjExZWZhNTdjMzA0M2Q3ZWU1MzdlIg.ZmldwA.9oP9pVtuEQSpg-Z18A2eOkWO-3E'
 
 
 HOST_ADDRESS = 'http://127.0.0.1:9380'
\ No newline at end of file
diff --git a/sdk/python/test/test_dataset.py b/sdk/python/test/test_dataset.py
index 868eddcd1..04e168a75 100644
--- a/sdk/python/test/test_dataset.py
+++ b/sdk/python/test/test_dataset.py
@@ -1,10 +1,10 @@
 from test_sdkbase import TestSdk
-import ragflow
-from ragflow.ragflow import RAGFLow
+from ragflow import RAGFlow
 import pytest
-from unittest.mock import MagicMock
 from common import API_KEY, HOST_ADDRESS
+
+
 
 class TestDataset(TestSdk):
 
     def test_create_dataset(self):
@@ -15,12 +15,92 @@ class TestDataset(TestSdk):
         4. update the kb
         5. delete the kb
         '''
-        ragflow = RAGFLow(API_KEY, HOST_ADDRESS)
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
 
         # create a kb
         res = ragflow.create_dataset("kb1")
         assert res['code'] == 0 and res['message'] == 'success'
-        dataset_id = res['data']['dataset_id']
-        print(dataset_id)
+        dataset_name = res['data']['dataset_name']
+
+    def test_list_dataset_success(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        # Call the list_datasets method
+        response = ragflow.list_dataset()
+
+        code, datasets = response
+
+        assert code == 200
+
+    def test_list_dataset_with_checking_size_and_name(self):
+        datasets_to_create = ["dataset1", "dataset2", "dataset3"]
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
+
+        real_name_to_create = set()
+        for response in created_response:
+            assert 'data' in response, "Response is missing 'data' key"
+            dataset_name = response['data']['dataset_name']
+            real_name_to_create.add(dataset_name)
+
+        status_code, listed_data = ragflow.list_dataset(0, 3)
+        listed_data = listed_data['data']
+
+        listed_names = {d['name'] for d in listed_data}
+        assert listed_names == real_name_to_create
+        assert status_code == 200
+        assert len(listed_data) == len(datasets_to_create)
+
+    def test_list_dataset_with_getting_empty_result(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        datasets_to_create = []
+        created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
+
+        real_name_to_create = set()
+        for response in created_response:
+            assert 'data' in response, "Response is missing 'data' key"
+            dataset_name = response['data']['dataset_name']
+            real_name_to_create.add(dataset_name)
+
+        status_code, listed_data = ragflow.list_dataset(0, 0)
+        listed_data = listed_data['data']
+
+        listed_names = {d['name'] for d in listed_data}
+        assert listed_names == real_name_to_create
+        assert status_code == 200
+        assert len(listed_data) == 0
+
+    def test_list_dataset_with_creating_100_knowledge_bases(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        datasets_to_create = ["dataset1"] * 100
+        created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
+
+        real_name_to_create = set()
+        for response in created_response:
+            assert 'data' in response, "Response is missing 'data' key"
+            dataset_name = response['data']['dataset_name']
+            real_name_to_create.add(dataset_name)
+
+        status_code, listed_data = ragflow.list_dataset(0, 100)
+        listed_data = listed_data['data']
+
+        listed_names = {d['name'] for d in listed_data}
+        assert listed_names == real_name_to_create
+        assert status_code == 200
+        assert len(listed_data) == 100
+
+    def test_list_dataset_with_showing_one_dataset(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        response = ragflow.list_dataset(0, 1)
+        code, response = response
+        datasets = response['data']
+        assert len(datasets) == 1
+
+    def test_list_dataset_failure(self):
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        response = ragflow.list_dataset(-1, -1)
+        _, res = response
+        assert "IndexError" in res['message']
+
+
+
 
-    # TODO: list the kb
\ No newline at end of file
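
Reviewer note (not part of the patch): below is a minimal usage sketch of the new `list_dataset` SDK call, mirroring the happy-path tests above. It assumes a running RAGFlow server and a valid API key; `"<your_api_key>"` and the host address are placeholders (the values used by the tests live in sdk/python/test/common.py).

```python
# Usage sketch for RAGFlow.list_dataset() as added in this patch.
# "<your_api_key>" is a placeholder; supply your own key and host.
from ragflow import RAGFlow

ragflow = RAGFlow("<your_api_key>", "http://127.0.0.1:9380")

# list_dataset() returns (status_code, json_body) on success, or None if a
# request error was caught and printed inside the SDK, so guard the unpacking.
result = ragflow.list_dataset(offset=0, count=10, orderby="create_time", desc=True)
if result is not None:
    status_code, body = result
    if status_code == 200:
        # The server wraps the dataset records under the "data" key.
        for dataset in body["data"]:
            print(dataset["name"])
```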