complete implementation of dataset SDK (#2147)

### What problem does this PR solve?

Complete implementation of dataset SDK.
#1102

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
LiuHua 2024-08-29 14:31:31 +08:00 committed by GitHub
parent fc1ac3a962
commit f87e7242cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 260 additions and 86 deletions

View File

@ -15,82 +15,156 @@
#
from flask import request
from api.db import StatusEnum
from api.db.db_models import APIToken
from api.db import StatusEnum, FileSource
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result
from api.utils.api_utils import get_json_result, token_required, get_data_error_result
@manager.route('/save', methods=['POST'])
def save():
@token_required
def save(tenant_id):
req = request.json
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
e, t = TenantService.get_by_id(tenant_id)
if not e:
return get_data_error_result(retmsg="Tenant not found.")
if "id" not in req:
if "tenant_id" in req or "embd_id" in req:
return get_data_error_result(
retmsg="Tenant_id or embedding_model must not be provided")
if "name" not in req:
return get_data_error_result(
retmsg="Name is not empty!")
req['id'] = get_uuid()
req["name"] = req["name"].strip()
if req["name"] == "":
return get_data_error_result(
retmsg="Name is not empty")
if KnowledgebaseService.query(name=req["name"]):
retmsg="Name is not empty string!")
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(
retmsg="Duplicated knowledgebase name")
retmsg="Duplicated knowledgebase name in creating dataset.")
req["tenant_id"] = tenant_id
req['created_by'] = tenant_id
req['embd_id'] = t.embd_id
if not KnowledgebaseService.save(**req):
return get_data_error_result(retmsg="Data saving error")
req.pop('created_by')
keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
for old_key,new_key in keys_to_rename.items():
if old_key in req:
req[new_key]=req.pop(old_key)
return get_data_error_result(retmsg="Create dataset error.(Database error)")
return get_json_result(data=req)
else:
if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id:
return get_data_error_result(
retmsg="Can't change tenant_id or embedding_model")
if "tenant_id" in req:
if req["tenant_id"] != tenant_id:
return get_data_error_result(
retmsg="Can't change tenant_id.")
e, kb = KnowledgebaseService.get_by_id(req["id"])
if not e:
return get_data_error_result(
retmsg="Can't find this knowledgebase!")
if "embd_id" in req:
if req["embd_id"] != t.embd_id:
return get_data_error_result(
retmsg="Can't change embedding_model.")
if not KnowledgebaseService.query(
created_by=tenant_id, id=req["id"]):
return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR)
if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count or chunk_count ")
e, kb = KnowledgebaseService.get_by_id(req["id"])
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
return get_data_error_result(
retmsg="if chunk count is not 0, parser method is not changable. ")
if "chunk_num" in req:
if req["chunk_num"] != kb.chunk_num:
return get_data_error_result(
retmsg="Can't change chunk_count.")
if "doc_num" in req:
if req['doc_num'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count.")
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'],
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name.")
if "parser_id" in req:
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
return get_data_error_result(
retmsg="if chunk count is not 0, parse method is not changable.")
if "name" in req:
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name in updating dataset.")
del req["id"]
req['created_by'] = tenant_id
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_data_error_result(retmsg="Data update error ")
return get_data_error_result(retmsg="Update dataset error.(Database error)")
return get_json_result(data=True)
@manager.route('/delete', methods=['DELETE'])
@token_required
def delete(tenant_id):
req = request.args
kbs = KnowledgebaseService.query(
created_by=tenant_id, id=req["id"])
if not kbs:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)
for doc in DocumentService.query(kb_id=req["id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
return get_data_error_result(
retmsg="Remove document error.(Database error)")
f2d = File2DocumentService.get_by_document_id(doc.id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
if not KnowledgebaseService.delete_by_id(req["id"]):
return get_data_error_result(
retmsg="Delete dataset error.(Database error)")
return get_json_result(data=True)
@manager.route('/list', methods=['GET'])
@token_required
def list_datasets(tenant_id):
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 1024))
orderby = request.args.get("orderby", "create_time")
desc = bool(request.args.get("desc", True))
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
kbs = KnowledgebaseService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc)
return get_json_result(data=kbs)
@manager.route('/detail', methods=['GET'])
@token_required
def detail(tenant_id):
req = request.args
if "id" in req:
id = req["id"]
kb = KnowledgebaseService.query(created_by=tenant_id, id=req["id"])
if not kb:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)
if "name" in req:
name = req["name"]
if kb[0].name != name:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)
e, k = KnowledgebaseService.get_by_id(id)
return get_json_result(data=k.to_dict())
else:
if "name" in req:
name = req["name"]
e, k = KnowledgebaseService.get_by_name(kb_name=name, tenant_id=tenant_id)
if not e:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)
return get_json_result(data=k.to_dict())
else:
return get_data_error_result(
retmsg="At least one of `id` or `name` must be provided.")

View File

@ -13,30 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import json
import random
import time
from base64 import b64encode
from functools import wraps
from hmac import HMAC
from io import BytesIO
from urllib.parse import quote, urlencode
from uuid import uuid1
import requests
from flask import (
Response, jsonify, send_file, make_response,
request as flask_request,
)
from werkzeug.http import HTTP_STATUS_CODES
from api.utils import json_dumps
from api.settings import RetCode
from api.db.db_models import APIToken
from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
import requests
import functools
from api.settings import RetCode
from api.utils import CustomJSONEncoder
from uuid import uuid1
from base64 import b64encode
from hmac import HMAC
from urllib.parse import quote, urlencode
from api.utils import json_dumps
requests.models.complexjson.dumps = functools.partial(
json.dumps, cls=CustomJSONEncoder)
@ -96,7 +98,6 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
data=None, job_id=None, meta=None):
import re
result_dict = {
"retcode": retcode,
"retmsg": retmsg,
@ -145,7 +146,8 @@ def server_error_response(e):
return get_json_result(
retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg="No chunk found, please upload file and parse it.")
return get_json_result(retcode=RetCode.EXCEPTION_ERROR,
retmsg="No chunk found, please upload file and parse it.")
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
@ -190,7 +192,9 @@ def validate_request(*args, **kwargs):
return get_json_result(
retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
return func(*_args, **_kwargs)
return decorated_function
return wrapper
@ -217,7 +221,7 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
def construct_response(retcode=RetCode.SUCCESS,
retmsg='success', data=None, auth=None):
retmsg='success', data=None, auth=None):
result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
response_dict = {}
for key, value in result_dict.items():
@ -235,6 +239,7 @@ def construct_response(retcode=RetCode.SUCCESS,
response.headers["Access-Control-Expose-Headers"] = "Authorization"
return response
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
import re
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
@ -263,7 +268,23 @@ def construct_error_response(e):
pass
if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >=0:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
if repr(e).find("index_not_found_exception") >= 0:
return construct_json_result(code=RetCode.EXCEPTION_ERROR,
message="No chunk found, please upload file and parse it.")
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
def token_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
token = flask_request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR
)
kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs)
return decorated_function

View File

@ -18,13 +18,17 @@ class Base(object):
pr[name] = value
return pr
def post(self, path, param):
res = self.rag.post(path,param)
res = self.rag.post(path, param)
return res
def get(self, path, params=''):
res = self.rag.get(path,params)
def get(self, path, params):
res = self.rag.get(path, params)
return res
def rm(self, path, params):
res = self.rag.delete(path, params)
return res
def __str__(self):
return str(self.to_json())

View File

@ -21,18 +21,36 @@ class DataSet(Base):
self.permission = "me"
self.document_count = 0
self.chunk_count = 0
self.parser_method = "naive"
self.parse_method = "naive"
self.parser_config = None
for k in list(res_dict.keys()):
if k == "embd_id":
res_dict["embedding_model"] = res_dict[k]
if k == "parser_id":
res_dict['parse_method'] = res_dict[k]
if k == "doc_num":
res_dict["document_count"] = res_dict[k]
if k == "chunk_num":
res_dict["chunk_count"] = res_dict[k]
if k not in self.__dict__:
res_dict.pop(k)
super().__init__(rag, res_dict)
def save(self):
def save(self) -> bool:
res = self.post('/dataset/save',
{"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
"description": self.description, "language": self.language, "embd_id": self.embedding_model,
"permission": self.permission,
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method,
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parse_method,
"parser_config": self.parser_config.to_json()
})
res = res.json()
if not res.get("retmsg"): return True
raise Exception(res["retmsg"])
if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"])
def delete(self) -> bool:
res = self.rm('/dataset/delete',
{"id": self.id})
res = res.json()
if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"])

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import requests
from .modules.dataset import DataSet
@ -25,30 +27,54 @@ class RAGFlow:
"""
self.user_key = user_key
self.api_url = f"{base_url}/api/{version}"
self.authorization_header = {"Authorization": "{} {}".format("Bearer",self.user_key)}
self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
def post(self, path, param):
res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header)
return res
def get(self, path, params=''):
res = requests.get(self.api_url + path, params=params, headers=self.authorization_header)
def get(self, path, params=None):
res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
return res
def create_dataset(self, name:str,avatar:str="",description:str="",language:str="English",permission:str="me",
document_count:int=0,chunk_count:int=0,parser_method:str="naive",
parser_config:DataSet.ParserConfig=None):
def delete(self, path, params):
res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
return res
def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
permission: str = "me",
document_count: int = 0, chunk_count: int = 0, parse_method: str = "naive",
parser_config: DataSet.ParserConfig = None) -> DataSet:
if parser_config is None:
parser_config = DataSet.ParserConfig(self, {"chunk_token_count":128,"layout_recognize": True, "delimiter":"\n!?。;!?","task_page_size":12})
parser_config=parser_config.to_json()
res=self.post("/dataset/save",{"name":name,"avatar":avatar,"description":description,"language":language,"permission":permission,
"doc_num": document_count,"chunk_num":chunk_count,"parser_id":parser_method,
"parser_config":parser_config
}
)
parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True,
"delimiter": "\n!?。;!?", "task_page_size": 12})
parser_config = parser_config.to_json()
res = self.post("/dataset/save",
{"name": name, "avatar": avatar, "description": description, "language": language,
"permission": permission,
"doc_num": document_count, "chunk_num": chunk_count, "parser_id": parse_method,
"parser_config": parser_config
}
)
res = res.json()
if not res.get("retmsg"):
if res.get("retmsg") == "success":
return DataSet(self, res["data"])
raise Exception(res["retmsg"])
def list_datasets(self, page: int = 1, page_size: int = 150, orderby: str = "create_time", desc: bool = True) -> \
List[DataSet]:
res = self.get("/dataset/list", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
res = res.json()
result_list = []
if res.get("retmsg") == "success":
for data in res['data']:
result_list.append(DataSet(self, data))
return result_list
raise Exception(res["retmsg"])
def get_dataset(self, id: str = None, name: str = None) -> DataSet:
res = self.get("/dataset/detail", {"id": id, "name": name})
res = res.json()
if res.get("retmsg") == "success":
return DataSet(self, res['data'])
raise Exception(res["retmsg"])

View File

@ -7,7 +7,7 @@ from test_sdkbase import TestSdk
class TestDataset(TestSdk):
def test_create_dataset_with_success(self):
"""
Test creating dataset with success
Test creating a dataset with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("God")
@ -18,15 +18,46 @@ class TestDataset(TestSdk):
def test_update_dataset_with_success(self):
"""
Test updating dataset with success.
Test updating a dataset with success.
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("ABC")
if isinstance(ds, DataSet):
assert ds.name == "ABC", "Name does not match."
assert ds.name == "ABC", "Name does not match."
ds.name = 'DEF'
res = ds.save()
assert res is True, f"Failed to update dataset, error: {res}"
assert res is True, f"Failed to update dataset, error: {res}"
else:
assert False, f"Failed to create dataset, error: {ds}"
assert False, f"Failed to create dataset, error: {ds}"
def test_delete_dataset_with_success(self):
"""
Test deleting a dataset with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("MA")
if isinstance(ds, DataSet):
assert ds.name == "MA", "Name does not match."
res = ds.delete()
assert res is True, f"Failed to delete dataset, error: {res}"
else:
assert False, f"Failed to create dataset, error: {ds}"
def test_list_datasets_with_success(self):
"""
Test listing datasets with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
list_datasets = rag.list_datasets()
assert len(list_datasets) > 0, "Do not exist any dataset"
for ds in list_datasets:
assert isinstance(ds, DataSet), "Existence type is not dataset."
def test_get_detail_dataset_with_success(self):
"""
Test getting a dataset's detail with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.get_dataset(name="God")
assert isinstance(ds, DataSet), f"Failed to get dataset, error: {ds}."
assert ds.name == "God", "Name does not match"