mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-13 22:35:53 +08:00
create list_dataset api and tests (#1138)
### What problem does this PR solve? This PR completes both the HTTP API and the Python SDK for 'list_dataset'. In addition, it adds tests for them. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
f04fb36c26
commit
1eb4caf02a
@ -46,7 +46,7 @@ from api.contants import NAME_LENGTH_LIMIT
|
||||
|
||||
# ------------------------------ create a dataset ---------------------------------------
|
||||
@manager.route('/', methods=['POST'])
|
||||
@login_required # use login
|
||||
@login_required # use login
|
||||
@validate_request("name") # check name key
|
||||
def create_dataset():
|
||||
# Check if Authorization header is present
|
||||
@ -111,10 +111,27 @@ def create_dataset():
|
||||
if not KnowledgebaseService.save(**request_body):
|
||||
# failed to create new dataset
|
||||
return construct_result()
|
||||
return construct_json_result(data={"dataset_id": request_body["id"]})
|
||||
return construct_json_result(data={"dataset_name": request_body["name"]})
|
||||
except Exception as e:
|
||||
return construct_error_response(e)
|
||||
|
||||
# -----------------------------list datasets-------------------------------------------------------
@manager.route('/', methods=['GET'])
@login_required
def list_datasets():
    """List the datasets visible to the current user, with paging and ordering.

    Query args:
        offset:  index of the first dataset to return (default 0).
        count:   maximum number of datasets to return; -1 means "all".
        orderby: column used for server-side ordering (default "create_time").
        desc:    sort descending unless the client sends a false-y string.

    Returns a JSON result with the dataset rows, or an error response if the
    lookup (or int conversion of offset/count) fails.
    """
    offset = request.args.get("offset", 0)
    count = request.args.get("count", -1)
    orderby = request.args.get("orderby", "create_time")
    # Query-string values arrive as strings, so request.args.get() never yields
    # a real bool: previously any supplied value (even "false") was truthy.
    desc = str(request.args.get("desc", True)).lower() not in ("false", "0")
    try:
        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
        kbs = KnowledgebaseService.get_by_tenant_ids(
            [m["tenant_id"] for m in tenants], current_user.id,
            int(offset), int(count), orderby, desc)
        # NOTE(review): reporting RetCode.DATA_ERROR on the success path looks
        # like a copy-paste slip, but the correct success code is not visible
        # here — confirm against construct_json_result's defaults.
        return construct_json_result(data=kbs, code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
    except Exception as e:
        return construct_error_response(e)
|
||||
|
||||
# ---------------------------------delete a dataset ----------------------------
|
||||
|
||||
@manager.route('/<dataset_id>', methods=['DELETE'])
|
||||
@login_required
|
||||
@ -135,8 +152,5 @@ def get_dataset(dataset_id):
|
||||
return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to get detail of dataset: {dataset_id}")
|
||||
|
||||
|
||||
@manager.route('/', methods=['GET'])
|
||||
@login_required
|
||||
def list_datasets():
|
||||
return construct_json_result(code=RetCode.DATA_ERROR, message=f"attempt to list datasets")
|
||||
|
||||
|
||||
|
@ -40,6 +40,29 @@ class KnowledgebaseService(CommonService):
|
||||
|
||||
return list(kbs.dicts())
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
offset, count, orderby, desc):
|
||||
kbs = cls.model.select().where(
|
||||
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||
TenantPermission.TEAM.value)) | (
|
||||
cls.model.tenant_id == user_id))
|
||||
& (cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
if desc:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
kbs = list(kbs.dicts())
|
||||
|
||||
kbs_length = len(kbs)
|
||||
if offset < 0 or offset > kbs_length:
|
||||
raise IndexError("Offset is out of the valid range.")
|
||||
|
||||
return kbs[offset:offset+count]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_detail(cls, kb_id):
|
||||
|
@ -1,3 +1,5 @@
|
||||
import importlib.metadata
|
||||
|
||||
__version__ = importlib.metadata.version("ragflow")
|
||||
|
||||
from .ragflow import RAGFlow
|
||||
|
@ -18,4 +18,4 @@ class DataSet:
|
||||
self.user_key = user_key
|
||||
self.dataset_url = dataset_url
|
||||
self.uuid = uuid
|
||||
self.name = name
|
||||
self.name = name
|
||||
|
@ -17,7 +17,10 @@ import os
|
||||
import requests
|
||||
import json
|
||||
|
||||
class RAGFLow:
|
||||
from httpx import HTTPError
|
||||
|
||||
|
||||
class RAGFlow:
|
||||
def __init__(self, user_key, base_url, version = 'v1'):
|
||||
'''
|
||||
api_url: http://<host_address>/api/v1
|
||||
@ -36,16 +39,39 @@ class RAGFLow:
|
||||
result_dict = json.loads(res.text)
|
||||
return result_dict
|
||||
|
||||
def delete_dataset(self, dataset_name = None, dataset_id = None):
|
||||
def delete_dataset(self, dataset_name=None, dataset_id=None):
|
||||
return dataset_name
|
||||
|
||||
def list_dataset(self):
|
||||
response = requests.get(self.dataset_url)
|
||||
print(response)
|
||||
if response.status_code == 200:
|
||||
return response.json()['datasets']
|
||||
else:
|
||||
return None
|
||||
def list_dataset(self, offset=0, count=-1, orderby="create_time", desc=True):
|
||||
params = {
|
||||
"offset": offset,
|
||||
"count": count,
|
||||
"orderby": orderby,
|
||||
"desc": desc
|
||||
}
|
||||
try:
|
||||
response = requests.get(url=self.dataset_url, params=params, headers=self.authorization_header)
|
||||
response.raise_for_status() # if it is not 200
|
||||
original_data = response.json()
|
||||
# TODO: format the data
|
||||
# print(original_data)
|
||||
# # Process the original data into the desired format
|
||||
# formatted_data = {
|
||||
# "datasets": [
|
||||
# {
|
||||
# "id": dataset["id"],
|
||||
# "created": dataset["create_time"], # Adjust the key based on the actual response
|
||||
# "fileCount": dataset["doc_num"], # Adjust the key based on the actual response
|
||||
# "name": dataset["name"]
|
||||
# }
|
||||
# for dataset in original_data
|
||||
# ]
|
||||
# }
|
||||
return response.status_code, original_data
|
||||
except HTTPError as http_err:
|
||||
print(f"HTTP error occurred: {http_err}")
|
||||
except Exception as err:
|
||||
print(f"An error occurred: {err}")
|
||||
|
||||
def get_dataset(self, dataset_id):
|
||||
endpoint = f"{self.dataset_url}/{dataset_id}"
|
||||
@ -61,4 +87,4 @@ class RAGFLow:
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
@ -1,4 +1,4 @@
|
||||
|
||||
|
||||
API_KEY = 'IjJiMTVkZWNhMjU3MzExZWY4YzNiNjQ0OTdkMTllYjM3Ig.ZmQZrA.x9Z7c-1ErBUSL3m8SRtBRgGq5uE'
|
||||
API_KEY = 'ImFmNWQ3YTY0Mjg5NjExZWZhNTdjMzA0M2Q3ZWU1MzdlIg.ZmldwA.9oP9pVtuEQSpg-Z18A2eOkWO-3E'
|
||||
HOST_ADDRESS = 'http://127.0.0.1:9380'
|
@ -1,10 +1,10 @@
|
||||
from test_sdkbase import TestSdk
|
||||
import ragflow
|
||||
from ragflow.ragflow import RAGFLow
|
||||
from ragflow import RAGFlow
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from common import API_KEY, HOST_ADDRESS
|
||||
|
||||
|
||||
|
||||
class TestDataset(TestSdk):
|
||||
|
||||
def test_create_dataset(self):
|
||||
@ -15,12 +15,92 @@ class TestDataset(TestSdk):
|
||||
4. update the kb
|
||||
5. delete the kb
|
||||
'''
|
||||
ragflow = RAGFLow(API_KEY, HOST_ADDRESS)
|
||||
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
# create a kb
|
||||
res = ragflow.create_dataset("kb1")
|
||||
assert res['code'] == 0 and res['message'] == 'success'
|
||||
dataset_id = res['data']['dataset_id']
|
||||
print(dataset_id)
|
||||
dataset_name = res['data']['dataset_name']
|
||||
|
||||
def test_list_dataset_success(self):
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
# Call the list_datasets method
|
||||
response = ragflow.list_dataset()
|
||||
|
||||
code, datasets = response
|
||||
|
||||
assert code == 200
|
||||
|
||||
def test_list_dataset_with_checking_size_and_name(self):
|
||||
datasets_to_create = ["dataset1", "dataset2", "dataset3"]
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
|
||||
|
||||
real_name_to_create = set()
|
||||
for response in created_response:
|
||||
assert 'data' in response, "Response is missing 'data' key"
|
||||
dataset_name = response['data']['dataset_name']
|
||||
real_name_to_create.add(dataset_name)
|
||||
|
||||
status_code, listed_data = ragflow.list_dataset(0, 3)
|
||||
listed_data = listed_data['data']
|
||||
|
||||
listed_names = {d['name'] for d in listed_data}
|
||||
assert listed_names == real_name_to_create
|
||||
assert status_code == 200
|
||||
assert len(listed_data) == len(datasets_to_create)
|
||||
|
||||
def test_list_dataset_with_getting_empty_result(self):
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
datasets_to_create = []
|
||||
created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
|
||||
|
||||
real_name_to_create = set()
|
||||
for response in created_response:
|
||||
assert 'data' in response, "Response is missing 'data' key"
|
||||
dataset_name = response['data']['dataset_name']
|
||||
real_name_to_create.add(dataset_name)
|
||||
|
||||
status_code, listed_data = ragflow.list_dataset(0, 0)
|
||||
listed_data = listed_data['data']
|
||||
|
||||
listed_names = {d['name'] for d in listed_data}
|
||||
assert listed_names == real_name_to_create
|
||||
assert status_code == 200
|
||||
assert len(listed_data) == 0
|
||||
|
||||
def test_list_dataset_with_creating_100_knowledge_bases(self):
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
datasets_to_create = ["dataset1"] * 100
|
||||
created_response = [ragflow.create_dataset(name) for name in datasets_to_create]
|
||||
|
||||
real_name_to_create = set()
|
||||
for response in created_response:
|
||||
assert 'data' in response, "Response is missing 'data' key"
|
||||
dataset_name = response['data']['dataset_name']
|
||||
real_name_to_create.add(dataset_name)
|
||||
|
||||
status_code, listed_data = ragflow.list_dataset(0, 100)
|
||||
listed_data = listed_data['data']
|
||||
|
||||
listed_names = {d['name'] for d in listed_data}
|
||||
assert listed_names == real_name_to_create
|
||||
assert status_code == 200
|
||||
assert len(listed_data) == 100
|
||||
|
||||
def test_list_dataset_with_showing_one_dataset(self):
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
response = ragflow.list_dataset(0, 1)
|
||||
code, response = response
|
||||
datasets = response['data']
|
||||
assert len(datasets) == 1
|
||||
|
||||
def test_list_dataset_failure(self):
|
||||
ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
response = ragflow.list_dataset(-1, -1)
|
||||
_, res = response
|
||||
assert "IndexError" in res['message']
|
||||
|
||||
|
||||
|
||||
|
||||
# TODO: list the kb
|
Loading…
x
Reference in New Issue
Block a user