diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py
index e69de29bb..7a885ab38 100644
--- a/api/apps/sdk/dataset.py
+++ b/api/apps/sdk/dataset.py
@@ -0,0 +1,96 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from flask import request
+
+from api.db import StatusEnum
+from api.db.db_models import APIToken
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.user_service import TenantService
+from api.settings import RetCode
+from api.utils import get_uuid
+from api.utils.api_utils import get_data_error_result
+from api.utils.api_utils import get_json_result
+
+
+@manager.route('/save', methods=['POST'])
+def save():
+    req = request.json
+    token = request.headers.get('Authorization').split()[1]
+    objs = APIToken.query(token=token)
+    if not objs:
+        return get_json_result(
+            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)
+    tenant_id = objs[0].tenant_id
+    e, t = TenantService.get_by_id(tenant_id)
+    if not e:
+        return get_data_error_result(retmsg="Tenant not found.")
+    if "id" not in req:
+        req['id'] = get_uuid()
+        req["name"] = req["name"].strip()
+        if req["name"] == "":
+            return get_data_error_result(
+                retmsg="Name cannot be empty.")
+        if KnowledgebaseService.query(name=req["name"]):
+            return get_data_error_result(
+                retmsg="Duplicated knowledgebase name.")
+        req["tenant_id"] = tenant_id
+        req['created_by'] = tenant_id
+        req['embd_id'] = t.embd_id
+        if not KnowledgebaseService.save(**req):
+            return get_data_error_result(retmsg="Data saving error.")
+        req.pop('created_by')
+        keys_to_rename = {'embd_id': "embedding_model", 'parser_id': 'parser_method',
+                          'chunk_num': 'chunk_count', 'doc_num': 'document_count'}
+        for old_key, new_key in keys_to_rename.items():
+            if old_key in req:
+                req[new_key] = req.pop(old_key)
+        return get_json_result(data=req)
+    else:
+        if req["tenant_id"] != tenant_id or req["embd_id"] != t.embd_id:
+            return get_data_error_result(
+                retmsg="Can't change tenant_id or embedding_model.")
+
+        e, kb = KnowledgebaseService.get_by_id(req["id"])
+        if not e:
+            return get_data_error_result(
+                retmsg="Can't find this knowledgebase!")
+
+        if not KnowledgebaseService.query(
+                created_by=tenant_id, id=req["id"]):
+            return get_json_result(
+                data=False, retmsg='Only the owner of the knowledgebase is authorized for this operation.',
+                retcode=RetCode.OPERATING_ERROR)
+
+        if req["chunk_num"] != kb.chunk_num or req['doc_num'] != kb.doc_num:
+            return get_data_error_result(
+                retmsg="Can't change document_count or chunk_count.")
+
+        if kb.chunk_num > 0 and req['parser_id'] != kb.parser_id:
+            return get_data_error_result(
+                retmsg="If chunk count is not 0, parser method is not changeable.")
+
+        if req["name"].lower() != kb.name.lower() \
+                and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'],
+                                                   status=StatusEnum.VALID.value)) > 0:
+            return get_data_error_result(
+                retmsg="Duplicated knowledgebase name.")
+
+        del req["id"]
+        req['created_by'] = tenant_id
+        if not KnowledgebaseService.update_by_id(kb.id, req):
+            return get_data_error_result(retmsg="Data update error.")
+        return get_json_result(data=True)
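For reference, the new route is reachable at /api/v1/dataset/save with a Bearer API token (the SDK changes below wire this up). A minimal sketch of exercising it with plain HTTP; the address and token are placeholders, not values taken from this PR:

import requests

BASE_URL = "http://127.0.0.1:9380"  # placeholder deployment address
API_KEY = "<your-api-token>"        # placeholder token

# Omitting "id" takes the create branch (the route generates a UUID);
# sending an existing "id" takes the update branch instead.
res = requests.post(
    f"{BASE_URL}/api/v1/dataset/save",
    json={"name": "example_kb"},
    headers={"Authorization": f"Bearer {API_KEY}"},
)
print(res.json())  # an empty "retmsg" is what the SDK treats as success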
") + + + if req["name"].lower() != kb.name.lower() \ + and len(KnowledgebaseService.query(name=req["name"], tenant_id=req['tenant_id'], + status=StatusEnum.VALID.value)) > 0: + return get_data_error_result( + retmsg="Duplicated knowledgebase name.") + + del req["id"] + req['created_by'] = tenant_id + if not KnowledgebaseService.update_by_id(kb.id, req): + return get_data_error_result(retmsg="Data update error ") + return get_json_result(data=True) diff --git a/sdk/python/ragflow/__init__.py b/sdk/python/ragflow/__init__.py index 1ef2f0879..fbdb1bcea 100644 --- a/sdk/python/ragflow/__init__.py +++ b/sdk/python/ragflow/__init__.py @@ -3,3 +3,4 @@ import importlib.metadata __version__ = importlib.metadata.version("ragflow") from .ragflow import RAGFlow +from .modules.dataset import DataSet \ No newline at end of file diff --git a/sdk/python/ragflow/modules/dataset.py b/sdk/python/ragflow/modules/dataset.py index 889f3703f..7689cf7fe 100644 --- a/sdk/python/ragflow/modules/dataset.py +++ b/sdk/python/ragflow/modules/dataset.py @@ -2,7 +2,7 @@ from .base import Base class DataSet(Base): - class ParseConfig(Base): + class ParserConfig(Base): def __init__(self, rag, res_dict): self.chunk_token_count = 128 self.layout_recognize = True @@ -21,13 +21,18 @@ class DataSet(Base): self.permission = "me" self.document_count = 0 self.chunk_count = 0 - self.parse_method = 0 + self.parser_method = "naive" self.parser_config = None super().__init__(rag, res_dict) - def delete(self): - try: - self.post("/rm", {"kb_id": self.id}) - return True - except Exception: - return False + def save(self): + res = self.post('/dataset/save', + {"id": self.id, "name": self.name, "avatar": self.avatar, "tenant_id": self.tenant_id, + "description": self.description, "language": self.language, "embd_id": self.embedding_model, + "permission": self.permission, + "doc_num": self.document_count, "chunk_num": self.chunk_count, "parser_id": self.parser_method, + "parser_config": self.parser_config.to_json() + }) + res = res.json() + if not res.get("retmsg"): return True + raise Exception(res["retmsg"]) \ No newline at end of file diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py index 3a94ea95f..ff3dba7da 100644 --- a/sdk/python/ragflow/ragflow.py +++ b/sdk/python/ragflow/ragflow.py @@ -21,180 +21,34 @@ from .modules.dataset import DataSet class RAGFlow: def __init__(self, user_key, base_url, version='v1'): """ - api_url: http:///v1 - dataset_url: http:///v1/kb - document_url: http:///v1/dataset/{dataset_id}/documents + api_url: http:///api/v1 """ self.user_key = user_key - self.api_url = f"{base_url}/{version}" - self.dataset_url = f"{self.api_url}/kb" - self.authorization_header = {"Authorization": "{}".format(self.user_key)} - self.base_url = base_url + self.api_url = f"{base_url}/api/{version}" + self.authorization_header = {"Authorization": "{} {}".format("Bearer",self.user_key)} def post(self, path, param): - res = requests.post(url=self.dataset_url + path, json=param, headers=self.authorization_header) + res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header) return res def get(self, path, params=''): - res = requests.get(self.dataset_url + path, params=params, headers=self.authorization_header) + res = requests.get(self.api_url + path, params=params, headers=self.authorization_header) return res - def create_dataset(self, dataset_name): - """ - name: dataset name - """ - res_create = self.post("/create", {"name": dataset_name}) - res_create_data = 
diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py
index 3a94ea95f..ff3dba7da 100644
--- a/sdk/python/ragflow/ragflow.py
+++ b/sdk/python/ragflow/ragflow.py
@@ -21,180 +21,34 @@ from .modules.dataset import DataSet
 
 class RAGFlow:
     def __init__(self, user_key, base_url, version='v1'):
         """
-        api_url: http://<host_address>/v1
-        dataset_url: http://<host_address>/v1/kb
-        document_url: http://<host_address>/v1/dataset/{dataset_id}/documents
+        api_url: http://<host_address>/api/v1
         """
         self.user_key = user_key
-        self.api_url = f"{base_url}/{version}"
-        self.dataset_url = f"{self.api_url}/kb"
-        self.authorization_header = {"Authorization": "{}".format(self.user_key)}
-        self.base_url = base_url
+        self.api_url = f"{base_url}/api/{version}"
+        self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
 
     def post(self, path, param):
-        res = requests.post(url=self.dataset_url + path, json=param, headers=self.authorization_header)
+        res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header)
         return res
 
     def get(self, path, params=''):
-        res = requests.get(self.dataset_url + path, params=params, headers=self.authorization_header)
+        res = requests.get(self.api_url + path, params=params, headers=self.authorization_header)
         return res
 
-    def create_dataset(self, dataset_name):
-        """
-        name: dataset name
-        """
-        res_create = self.post("/create", {"name": dataset_name})
-        res_create_data = res_create.json()['data']
-        res_detail = self.get("/detail", {"kb_id": res_create_data["kb_id"]})
-        res_detail_data = res_detail.json()['data']
-        result = {}
-        result['id'] = res_detail_data['id']
-        result['name'] = res_detail_data['name']
-        result['avatar'] = res_detail_data['avatar']
-        result['description'] = res_detail_data['description']
-        result['language'] = res_detail_data['language']
-        result['embedding_model'] = res_detail_data['embd_id']
-        result['permission'] = res_detail_data['permission']
-        result['document_count'] = res_detail_data['doc_num']
-        result['chunk_count'] = res_detail_data['chunk_num']
-        result['parser_config'] = res_detail_data['parser_config']
-        dataset = DataSet(self, result)
-        return dataset
+    def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
+                       permission: str = "me", document_count: int = 0, chunk_count: int = 0,
+                       parser_method: str = "naive", parser_config: DataSet.ParserConfig = None):
+        if parser_config is None:
+            parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True,
+                                                        "delimiter": "\n!?。;!?", "task_page_size": 12})
+        parser_config = parser_config.to_json()
+        res = self.post("/dataset/save",
+                        {"name": name, "avatar": avatar, "description": description, "language": language,
+                         "permission": permission, "doc_num": document_count, "chunk_num": chunk_count,
+                         "parser_id": parser_method, "parser_config": parser_config
+                         })
+        res = res.json()
+        if not res.get("retmsg"):
+            return DataSet(self, res["data"])
+        raise Exception(res["retmsg"])
 
-    """
-    def delete_dataset(self, dataset_name):
-        dataset_id = self.find_dataset_id_by_name(dataset_name)
-        endpoint = f"{self.dataset_url}/{dataset_id}"
-        res = requests.delete(endpoint, headers=self.authorization_header)
-        return res.json()
-
-    def find_dataset_id_by_name(self, dataset_name):
-        res = requests.get(self.dataset_url, headers=self.authorization_header)
-        for dataset in res.json()["data"]:
-            if dataset["name"] == dataset_name:
-                return dataset["id"]
-        return None
-
-    def get_dataset(self, dataset_name):
-        dataset_id = self.find_dataset_id_by_name(dataset_name)
-        endpoint = f"{self.dataset_url}/{dataset_id}"
-        response = requests.get(endpoint, headers=self.authorization_header)
-        return response.json()
-
-    def update_dataset(self, dataset_name, **params):
-        dataset_id = self.find_dataset_id_by_name(dataset_name)
-
-        endpoint = f"{self.dataset_url}/{dataset_id}"
-        response = requests.put(endpoint, json=params, headers=self.authorization_header)
-        return response.json()
-
-    # ------------------------------- CONTENT MANAGEMENT -----------------------------------------------------
-
-    # ----------------------------upload local files-----------------------------------------------------
-    def upload_local_file(self, dataset_id, file_paths):
-        files = []
-
-        for file_path in file_paths:
-            if not isinstance(file_path, str):
-                return {"code": RetCode.ARGUMENT_ERROR, "message": f"{file_path} is not string."}
-            if "http" in file_path:
-                return {"code": RetCode.ARGUMENT_ERROR, "message": "Remote files have not unsupported."}
-            if os.path.isfile(file_path):
-                files.append(("file", open(file_path, "rb")))
-            else:
-                return {"code": RetCode.DATA_ERROR, "message": f"The file {file_path} does not exist"}
-
-        res = requests.request("POST", url=f"{self.dataset_url}/{dataset_id}/documents", files=files,
-                               headers=self.authorization_header)
-
-        result_dict = json.loads(res.text)
-        return result_dict
-
-    # ----------------------------delete a file-----------------------------------------------------
-    def delete_files(self, document_id, dataset_id):
-        endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
-        res = requests.delete(endpoint, headers=self.authorization_header)
-        return res.json()
-
-    # ----------------------------list files-----------------------------------------------------
-    def list_files(self, dataset_id, offset=0, count=-1, order_by="create_time", descend=True, keywords=""):
-        params = {
-            "offset": offset,
-            "count": count,
-            "order_by": order_by,
-            "descend": descend,
-            "keywords": keywords
-        }
-        endpoint = f"{self.dataset_url}/{dataset_id}/documents/"
-        res = requests.get(endpoint, params=params, headers=self.authorization_header)
-        return res.json()
-
-    # ----------------------------update files: enable, rename, template_type-------------------------------------------
-    def update_file(self, dataset_id, document_id, **params):
-        endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
-        response = requests.put(endpoint, json=params, headers=self.authorization_header)
-        return response.json()
-
-    # ----------------------------download a file-----------------------------------------------------
-    def download_file(self, dataset_id, document_id):
-        endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
-        res = requests.get(endpoint, headers=self.authorization_header)
-
-        content = res.content  # binary data
-        # decode the binary data
-        try:
-            decoded_content = content.decode("utf-8")
-            json_data = json.loads(decoded_content)
-            return json_data  # message
-        except json.JSONDecodeError:  # binary data
-            _, document = DocumentService.get_by_id(document_id)
-            file_path = os.path.join(os.getcwd(), document.name)
-            with open(file_path, "wb") as file:
-                file.write(content)
-            return {"code": RetCode.SUCCESS, "data": content}
-
-    # ----------------------------start parsing-----------------------------------------------------
-    def start_parsing_document(self, dataset_id, document_id):
-        endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}/status"
-        res = requests.post(endpoint, headers=self.authorization_header)
-
-        return res.json()
-
-    def start_parsing_documents(self, dataset_id, doc_ids=None):
-        endpoint = f"{self.dataset_url}/{dataset_id}/documents/status"
-        res = requests.post(endpoint, headers=self.authorization_header, json={"doc_ids": doc_ids})
-
-        return res.json()
-
-    # ----------------------------stop parsing-----------------------------------------------------
-    def stop_parsing_document(self, dataset_id, document_id):
-        endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}/status"
-        res = requests.delete(endpoint, headers=self.authorization_header)
-
-        return res.json()
-
-    def stop_parsing_documents(self, dataset_id, doc_ids=None):
-        endpoint = f"{self.dataset_url}/{dataset_id}/documents/status"
-        res = requests.delete(endpoint, headers=self.authorization_header, json={"doc_ids": doc_ids})
-
-        return res.json()
-
-    # ----------------------------show the status of the file-----------------------------------------------------
-    def show_parsing_status(self, dataset_id, document_id):
-        endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}/status"
-        res = requests.get(endpoint, headers=self.authorization_header)
-
-        return res.json()
-    # ----------------------------list the chunks of the file-----------------------------------------------------
-
-    # ----------------------------delete the chunk-----------------------------------------------------
-
-    # ----------------------------edit the status of the chunk-----------------------------------------------------
-
-    # ----------------------------insert a new chunk-----------------------------------------------------
-
-    # ----------------------------get a specific chunk-----------------------------------------------------
-
-    # ----------------------------retrieval test-----------------------------------------------------
-"""
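With this rewrite, create_dataset posts directly to /dataset/save and returns a DataSet built from the response. A short end-to-end sketch; the key and address are placeholders, and only name is required since the rest fall back to the defaults in the signature:

from ragflow import RAGFlow, DataSet

rag = RAGFlow("<your-api-token>", "http://127.0.0.1:9380")  # placeholders

ds = rag.create_dataset(name="papers", description="A small test corpus")
assert isinstance(ds, DataSet)
# Field names come back renamed by the server, e.g. embd_id -> embedding_model.
print(ds.id, ds.name, ds.parser_method)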
diff --git a/sdk/python/test/common.py b/sdk/python/test/common.py
index ac286db40..5feca4777 100644
--- a/sdk/python/test/common.py
+++ b/sdk/python/test/common.py
@@ -1,4 +1,4 @@
-API_KEY = 'IjUxNGM0MmM4NWY5MzExZWY5MDhhMDI0MmFjMTIwMDA2Ig.ZsWebA.mV1NKdSPPllgowiH-7vz36tMWyI'
+API_KEY = 'ragflow-k0N2I1MzQwNjNhMzExZWY5ODg1MDI0Mm'
 HOST_ADDRESS = 'http://127.0.0.1:9380'
\ No newline at end of file
diff --git a/sdk/python/test/t_dataset.py b/sdk/python/test/t_dataset.py
index 35b1e4c97..1466233a1 100644
--- a/sdk/python/test/t_dataset.py
+++ b/sdk/python/test/t_dataset.py
@@ -1,4 +1,4 @@
-from ragflow import RAGFlow
+from ragflow import RAGFlow, DataSet
 from common import API_KEY, HOST_ADDRESS
 from test_sdkbase import TestSdk
 
@@ -6,18 +6,27 @@ from test_sdkbase import TestSdk
 
 class TestDataset(TestSdk):
     def test_create_dataset_with_success(self):
+        """
+        Test creating a dataset with success.
+        """
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         ds = rag.create_dataset("God")
-        assert ds is not None, "The dataset creation failed, returned None."
-        assert ds.name == "God", "Dataset name does not match."
+        if isinstance(ds, DataSet):
+            assert ds.name == "God", "Name does not match."
+        else:
+            assert False, f"Failed to create dataset, error: {ds}"
 
-    def test_delete_one_file(self):
+    def test_update_dataset_with_success(self):
         """
-        Test deleting one file with success.
+        Test updating a dataset with success.
         """
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         ds = rag.create_dataset("ABC")
-        assert ds is not None, "Failed to create dataset"
-        assert ds.name == "ABC", "Dataset name mismatch"
-        delete_result = ds.delete()
-        assert delete_result is True, "Failed to delete dataset"
+        if isinstance(ds, DataSet):
+            assert ds.name == "ABC", "Name does not match."
+            ds.name = 'DEF'
+            res = ds.save()
+            assert res is True, f"Failed to update dataset, error: {res}"
+        else:
+            assert False, f"Failed to create dataset, error: {ds}"
\ No newline at end of file
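One behavioral note the tests rely on: create_dataset raises on a non-empty retmsg rather than returning an error value, so the else branches above only guard against unexpected return types. A hypothetical companion test, not part of this PR and assuming the suite runs under pytest, that checks the duplicate-name rejection implemented in dataset.py above:

import pytest
from ragflow import RAGFlow
from common import API_KEY, HOST_ADDRESS


def test_create_duplicate_dataset_raises():
    # The /dataset/save create branch rejects duplicated knowledgebase
    # names; the SDK surfaces that retmsg as an Exception.
    rag = RAGFlow(API_KEY, HOST_ADDRESS)
    rag.create_dataset("dup")
    with pytest.raises(Exception, match="Duplicated knowledgebase name"):
        rag.create_dataset("dup")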