complete implementation of dataset SDK (#2147)

### What problem does this PR solve?

Complete implementation of dataset SDK.
#1102

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
LiuHua 2024-08-29 14:31:31 +08:00 committed by GitHub
parent fc1ac3a962
commit f87e7242cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 260 additions and 86 deletions

View File

@ -15,82 +15,156 @@
# #
from flask import request from flask import request
from api.db import StatusEnum from api.db import StatusEnum, FileSource
from api.db.db_models import APIToken from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.settings import RetCode from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result from api.utils.api_utils import get_json_result, token_required, get_data_error_result
from api.utils.api_utils import get_json_result
@manager.route('/save', methods=['POST']) @manager.route('/save', methods=['POST'])
def save(): @token_required
def save(tenant_id):
req = request.json req = request.json
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
e, t = TenantService.get_by_id(tenant_id) e, t = TenantService.get_by_id(tenant_id)
if not e:
return get_data_error_result(retmsg="Tenant not found.")
if "id" not in req: if "id" not in req:
if "tenant_id" in req or "embd_id" in req:
return get_data_error_result(
retmsg="Tenant_id or embedding_model must not be provided")
if "name" not in req:
return get_data_error_result(
retmsg="Name is not empty!")
req['id'] = get_uuid() req['id'] = get_uuid()
req["name"] = req["name"].strip() req["name"] = req["name"].strip()
if req["name"] == "": if req["name"] == "":
return get_data_error_result( return get_data_error_result(
retmsg="Name is not empty") retmsg="Name is not empty string!")
if KnowledgebaseService.query(name=req["name"]): if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result( return get_data_error_result(
retmsg="Duplicated knowledgebase name") retmsg="Duplicated knowledgebase name in creating dataset.")
req["tenant_id"] = tenant_id req["tenant_id"] = tenant_id
req['created_by'] = tenant_id req['created_by'] = tenant_id
req['embd_id'] = t.embd_id req['embd_id'] = t.embd_id
if not KnowledgebaseService.save(**req): if not KnowledgebaseService.save(**req):
return get_data_error_result(retmsg="Data saving error") return get_data_error_result(retmsg="Create dataset error.(Database error)")
req.pop('created_by')
keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
for old_key,new_key in keys_to_rename.items():
if old_key in req:
req[new_key]=req.pop(old_key)
return get_json_result(data=req) return get_json_result(data=req)
else: else:
if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id: if "tenant_id" in req:
return get_data_error_result( if req["tenant_id"] != tenant_id:
retmsg="Can't change tenant_id or embedding_model") return get_data_error_result(
retmsg="Can't change tenant_id.")
e, kb = KnowledgebaseService.get_by_id(req["id"]) if "embd_id" in req:
if not e: if req["embd_id"] != t.embd_id:
return get_data_error_result( return get_data_error_result(
retmsg="Can't find this knowledgebase!") retmsg="Can't change embedding_model.")
if not KnowledgebaseService.query( if not KnowledgebaseService.query(
created_by=tenant_id, id=req["id"]): created_by=tenant_id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR) retcode=RetCode.OPERATING_ERROR)
if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num: e, kb = KnowledgebaseService.get_by_id(req["id"])
return get_data_error_result(
retmsg="Can't change document_count or chunk_count ")
if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id: if "chunk_num" in req:
return get_data_error_result( if req["chunk_num"] != kb.chunk_num:
retmsg="if chunk count is not 0, parser method is not changable. ") return get_data_error_result(
retmsg="Can't change chunk_count.")
if "doc_num" in req:
if req['doc_num'] != kb.doc_num:
return get_data_error_result(
retmsg="Can't change document_count.")
if req["name"].lower() != kb.name.lower() \ if "parser_id" in req:
and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'], if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
status=StatusEnum.VALID.value)) > 0: return get_data_error_result(
return get_data_error_result( retmsg="if chunk count is not 0, parse method is not changable.")
retmsg="Duplicated knowledgebase name.") if "name" in req:
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(
retmsg="Duplicated knowledgebase name in updating dataset.")
del req["id"] del req["id"]
req['created_by'] = tenant_id
if not KnowledgebaseService.update_by_id(kb.id, req): if not KnowledgebaseService.update_by_id(kb.id, req):
return get_data_error_result(retmsg="Data update error ") return get_data_error_result(retmsg="Update dataset error.(Database error)")
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/delete', methods=['DELETE'])
@token_required
def delete(tenant_id):
req = request.args
kbs = KnowledgebaseService.query(
created_by=tenant_id, id=req["id"])
if not kbs:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)
for doc in DocumentService.query(kb_id=req["id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
return get_data_error_result(
retmsg="Remove document error.(Database error)")
f2d = File2DocumentService.get_by_document_id(doc.id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc.id)
if not KnowledgebaseService.delete_by_id(req["id"]):
return get_data_error_result(
retmsg="Delete dataset error.(Database error)")
return get_json_result(data=True)
@manager.route('/list', methods=['GET'])
@token_required
def list_datasets(tenant_id):
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 1024))
orderby = request.args.get("orderby", "create_time")
desc = bool(request.args.get("desc", True))
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
kbs = KnowledgebaseService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], tenant_id, page_number, items_per_page, orderby, desc)
return get_json_result(data=kbs)
@manager.route('/detail', methods=['GET'])
@token_required
def detail(tenant_id):
req = request.args
if "id" in req:
id = req["id"]
kb = KnowledgebaseService.query(created_by=tenant_id, id=req["id"])
if not kb:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)
if "name" in req:
name = req["name"]
if kb[0].name != name:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)
e, k = KnowledgebaseService.get_by_id(id)
return get_json_result(data=k.to_dict())
else:
if "name" in req:
name = req["name"]
e, k = KnowledgebaseService.get_by_name(kb_name=name, tenant_id=tenant_id)
if not e:
return get_json_result(
data=False, retmsg='You do not own the dataset',
retcode=RetCode.OPERATING_ERROR)
return get_json_result(data=k.to_dict())
else:
return get_data_error_result(
retmsg="At least one of `id` or `name` must be provided.")

View File

@ -13,30 +13,32 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import functools
import json import json
import random import random
import time import time
from base64 import b64encode
from functools import wraps from functools import wraps
from hmac import HMAC
from io import BytesIO from io import BytesIO
from urllib.parse import quote, urlencode
from uuid import uuid1
import requests
from flask import ( from flask import (
Response, jsonify, send_file, make_response, Response, jsonify, send_file, make_response,
request as flask_request, request as flask_request,
) )
from werkzeug.http import HTTP_STATUS_CODES from werkzeug.http import HTTP_STATUS_CODES
from api.utils import json_dumps from api.db.db_models import APIToken
from api.settings import RetCode
from api.settings import ( from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
) )
import requests from api.settings import RetCode
import functools
from api.utils import CustomJSONEncoder from api.utils import CustomJSONEncoder
from uuid import uuid1 from api.utils import json_dumps
from base64 import b64encode
from hmac import HMAC
from urllib.parse import quote, urlencode
requests.models.complexjson.dumps = functools.partial( requests.models.complexjson.dumps = functools.partial(
json.dumps, cls=CustomJSONEncoder) json.dumps, cls=CustomJSONEncoder)
@ -96,7 +98,6 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
data=None, job_id=None, meta=None): data=None, job_id=None, meta=None):
import re
result_dict = { result_dict = {
"retcode": retcode, "retcode": retcode,
"retmsg": retmsg, "retmsg": retmsg,
@ -145,7 +146,8 @@ def server_error_response(e):
return get_json_result( return get_json_result(
retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1]) retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >= 0: if repr(e).find("index_not_found_exception") >= 0:
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg="No chunk found, please upload file and parse it.") return get_json_result(retcode=RetCode.EXCEPTION_ERROR,
retmsg="No chunk found, please upload file and parse it.")
return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e)) return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
@ -190,7 +192,9 @@ def validate_request(*args, **kwargs):
return get_json_result( return get_json_result(
retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string) retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
return func(*_args, **_kwargs) return func(*_args, **_kwargs)
return decorated_function return decorated_function
return wrapper return wrapper
@ -217,7 +221,7 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
def construct_response(retcode=RetCode.SUCCESS, def construct_response(retcode=RetCode.SUCCESS,
retmsg='success', data=None, auth=None): retmsg='success', data=None, auth=None):
result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data} result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
response_dict = {} response_dict = {}
for key, value in result_dict.items(): for key, value in result_dict.items():
@ -235,6 +239,7 @@ def construct_response(retcode=RetCode.SUCCESS,
response.headers["Access-Control-Expose-Headers"] = "Authorization" response.headers["Access-Control-Expose-Headers"] = "Authorization"
return response return response
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
import re import re
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
@ -263,7 +268,23 @@ def construct_error_response(e):
pass pass
if len(e.args) > 1: if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
if repr(e).find("index_not_found_exception") >=0: if repr(e).find("index_not_found_exception") >= 0:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.") return construct_json_result(code=RetCode.EXCEPTION_ERROR,
message="No chunk found, please upload file and parse it.")
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
def token_required(func):
@wraps(func)
def decorated_function(*args, **kwargs):
token = flask_request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR
)
kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs)
return decorated_function

View File

@ -18,13 +18,17 @@ class Base(object):
pr[name] = value pr[name] = value
return pr return pr
def post(self, path, param): def post(self, path, param):
res = self.rag.post(path,param) res = self.rag.post(path, param)
return res return res
def get(self, path, params=''): def get(self, path, params):
res = self.rag.get(path,params) res = self.rag.get(path, params)
return res return res
def rm(self, path, params):
res = self.rag.delete(path, params)
return res
def __str__(self):
return str(self.to_json())

View File

@ -21,18 +21,36 @@ class DataSet(Base):
self.permission = "me" self.permission = "me"
self.document_count = 0 self.document_count = 0
self.chunk_count = 0 self.chunk_count = 0
self.parser_method = "naive" self.parse_method = "naive"
self.parser_config = None self.parser_config = None
for k in list(res_dict.keys()):
if k == "embd_id":
res_dict["embedding_model"] = res_dict[k]
if k == "parser_id":
res_dict['parse_method'] = res_dict[k]
if k == "doc_num":
res_dict["document_count"] = res_dict[k]
if k == "chunk_num":
res_dict["chunk_count"] = res_dict[k]
if k not in self.__dict__:
res_dict.pop(k)
super().__init__(rag, res_dict) super().__init__(rag, res_dict)
def save(self): def save(self) -> bool:
res = self.post('/dataset/save', res = self.post('/dataset/save',
{"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id, {"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id,
"description": self.description, "language": self.language, "embd_id": self.embedding_model, "description": self.description, "language": self.language, "embd_id": self.embedding_model,
"permission": self.permission, "permission": self.permission,
"doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method, "doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parse_method,
"parser_config": self.parser_config.to_json() "parser_config": self.parser_config.to_json()
}) })
res = res.json() res = res.json()
if not res.get("retmsg"): return True if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"])
def delete(self) -> bool:
res = self.rm('/dataset/delete',
{"id": self.id})
res = res.json()
if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List
import requests import requests
from .modules.dataset import DataSet from .modules.dataset import DataSet
@ -25,30 +27,54 @@ class RAGFlow:
""" """
self.user_key = user_key self.user_key = user_key
self.api_url = f"{base_url}/api/{version}" self.api_url = f"{base_url}/api/{version}"
self.authorization_header = {"Authorization": "{} {}".format("Bearer",self.user_key)} self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
def post(self, path, param): def post(self, path, param):
res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header) res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header)
return res return res
def get(self, path, params=''): def get(self, path, params=None):
res = requests.get(self.api_url + path, params=params, headers=self.authorization_header) res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header)
return res return res
def create_dataset(self, name:str,avatar:str="",description:str="",language:str="English",permission:str="me", def delete(self, path, params):
document_count:int=0,chunk_count:int=0,parser_method:str="naive", res = requests.delete(url=self.api_url + path, params=params, headers=self.authorization_header)
parser_config:DataSet.ParserConfig=None): return res
def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
permission: str = "me",
document_count: int = 0, chunk_count: int = 0, parse_method: str = "naive",
parser_config: DataSet.ParserConfig = None) -> DataSet:
if parser_config is None: if parser_config is None:
parser_config = DataSet.ParserConfig(self, {"chunk_token_count":128,"layout_recognize": True, "delimiter":"\n!?。;!?","task_page_size":12}) parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True,
parser_config=parser_config.to_json() "delimiter": "\n!?。;!?", "task_page_size": 12})
res=self.post("/dataset/save",{"name":name,"avatar":avatar,"description":description,"language":language,"permission":permission, parser_config = parser_config.to_json()
"doc_num": document_count,"chunk_num":chunk_count,"parser_id":parser_method, res = self.post("/dataset/save",
"parser_config":parser_config {"name": name, "avatar": avatar, "description": description, "language": language,
} "permission": permission,
) "doc_num": document_count, "chunk_num": chunk_count, "parser_id": parse_method,
"parser_config": parser_config
}
)
res = res.json() res = res.json()
if not res.get("retmsg"): if res.get("retmsg") == "success":
return DataSet(self, res["data"]) return DataSet(self, res["data"])
raise Exception(res["retmsg"]) raise Exception(res["retmsg"])
def list_datasets(self, page: int = 1, page_size: int = 150, orderby: str = "create_time", desc: bool = True) -> \
List[DataSet]:
res = self.get("/dataset/list", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
res = res.json()
result_list = []
if res.get("retmsg") == "success":
for data in res['data']:
result_list.append(DataSet(self, data))
return result_list
raise Exception(res["retmsg"])
def get_dataset(self, id: str = None, name: str = None) -> DataSet:
res = self.get("/dataset/detail", {"id": id, "name": name})
res = res.json()
if res.get("retmsg") == "success":
return DataSet(self, res['data'])
raise Exception(res["retmsg"])

View File

@ -7,7 +7,7 @@ from test_sdkbase import TestSdk
class TestDataset(TestSdk): class TestDataset(TestSdk):
def test_create_dataset_with_success(self): def test_create_dataset_with_success(self):
""" """
Test creating dataset with success Test creating a dataset with success
""" """
rag = RAGFlow(API_KEY, HOST_ADDRESS) rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("God") ds = rag.create_dataset("God")
@ -18,15 +18,46 @@ class TestDataset(TestSdk):
def test_update_dataset_with_success(self): def test_update_dataset_with_success(self):
""" """
Test updating dataset with success. Test updating a dataset with success.
""" """
rag = RAGFlow(API_KEY, HOST_ADDRESS) rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("ABC") ds = rag.create_dataset("ABC")
if isinstance(ds, DataSet): if isinstance(ds, DataSet):
assert ds.name == "ABC", "Name does not match." assert ds.name == "ABC", "Name does not match."
ds.name = 'DEF' ds.name = 'DEF'
res = ds.save() res = ds.save()
assert res is True, f"Failed to update dataset, error: {res}" assert res is True, f"Failed to update dataset, error: {res}"
else: else:
assert False, f"Failed to create dataset, error: {ds}" assert False, f"Failed to create dataset, error: {ds}"
def test_delete_dataset_with_success(self):
"""
Test deleting a dataset with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.create_dataset("MA")
if isinstance(ds, DataSet):
assert ds.name == "MA", "Name does not match."
res = ds.delete()
assert res is True, f"Failed to delete dataset, error: {res}"
else:
assert False, f"Failed to create dataset, error: {ds}"
def test_list_datasets_with_success(self):
"""
Test listing datasets with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
list_datasets = rag.list_datasets()
assert len(list_datasets) > 0, "Do not exist any dataset"
for ds in list_datasets:
assert isinstance(ds, DataSet), "Existence type is not dataset."
def test_get_detail_dataset_with_success(self):
"""
Test getting a dataset's detail with success
"""
rag = RAGFlow(API_KEY, HOST_ADDRESS)
ds = rag.get_dataset(name="God")
assert isinstance(ds, DataSet), f"Failed to get dataset, error: {ds}."
assert ds.name == "God", "Name does not match"