diff --git a/sdk/python/test/test_frontend_api/common.py b/sdk/python/test/test_frontend_api/common.py
index aa6e258e0..4e9835763 100644
--- a/sdk/python/test/test_frontend_api/common.py
+++ b/sdk/python/test/test_frontend_api/common.py
@@ -5,6 +5,7 @@ HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380')
 
 DATASET_NAME_LIMIT = 128
 
+
 def create_dataset(auth, dataset_name):
     authorization = {"Authorization": auth}
     url = f"{HOST_ADDRESS}/v1/kb/create"
@@ -27,8 +28,54 @@ def rm_dataset(auth, dataset_id):
     res = requests.post(url=url, headers=authorization, json=json)
     return res.json()
 
+
 def update_dataset(auth, json_req):
     authorization = {"Authorization": auth}
     url = f"{HOST_ADDRESS}/v1/kb/update"
     res = requests.post(url=url, headers=authorization, json=json_req)
     return res.json()
+
+
+def upload_file(auth, dataset_id, path):
+    authorization = {"Authorization": auth}
+    url = f"{HOST_ADDRESS}/v1/document/upload"
+    base_name = os.path.basename(path)
+    json_req = {
+        "kb_id": dataset_id,
+    }
+    # send the file under its base name and close the handle once the upload completes
+    with open(path, 'rb') as f:
+        res = requests.post(url=url, headers=authorization, files={'file': (base_name, f)}, data=json_req)
+    return res.json()
+
+
+def list_document(auth, dataset_id):
+    authorization = {"Authorization": auth}
+    url = f"{HOST_ADDRESS}/v1/document/list?kb_id={dataset_id}"
+    res = requests.get(url=url, headers=authorization)
+    return res.json()
+
+
+def get_docs_info(auth, doc_ids):
+    authorization = {"Authorization": auth}
+    json_req = {
+        "doc_ids": doc_ids
+    }
+    url = f"{HOST_ADDRESS}/v1/document/infos"
+    res = requests.post(url=url, headers=authorization, json=json_req)
+    return res.json()
+
+
+def parse_docs(auth, doc_ids):
+    authorization = {"Authorization": auth}
+    json_req = {
+        "doc_ids": doc_ids,
+        "run": 1
+    }
+    url = f"{HOST_ADDRESS}/v1/document/run"
+    res = requests.post(url=url, headers=authorization, json=json_req)
+    return res.json()
+
+
+def parse_file(auth, document_id):
+    pass
diff --git a/sdk/python/test/test_frontend_api/test_chunk.py b/sdk/python/test/test_frontend_api/test_chunk.py
new file mode 100644
index 000000000..555b93601
--- /dev/null
+++ b/sdk/python/test/test_frontend_api/test_chunk.py
@@ -0,0 +1,77 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from common import create_dataset, list_dataset, rm_dataset, upload_file
+from common import list_document, get_docs_info, parse_docs
+from time import sleep
+from timeit import default_timer as timer
+
+
+def test_parse_txt_document(get_auth):
+    # create dataset
+    res = create_dataset(get_auth, "test_parse_txt_document")
+    assert res.get("code") == 0, f"{res.get('message')}"
+
+    # list datasets page by page (page size is 150); dataset_id ends up
+    # holding the id of the dataset created above, and dataset_list
+    # collects every id so they can all be removed at the end
+    page_number = 1
+    dataset_list = []
+    dataset_id = None
+    while True:
+        res = list_dataset(get_auth, page_number)
+        data = res.get("data").get("kbs")
+        for item in data:
+            dataset_id = item.get("id")
+            dataset_list.append(dataset_id)
+        if len(dataset_list) < page_number * 150:
+            break
+        page_number += 1
+
+    # upload a plain-text document into the dataset
+    filename = 'ragflow_test.txt'
+    res = upload_file(get_auth, dataset_id, f"../test_sdk_api/test_data/{filename}")
+    assert res.get("code") == 0, f"{res.get('message')}"
+
+    res = list_document(get_auth, dataset_id)
+
+    doc_id_list = []
+    for doc in res['data']['docs']:
+        doc_id_list.append(doc['id'])
+
+    # kick off parsing for every uploaded document
+    doc_count = len(doc_id_list)
+    res = parse_docs(get_auth, doc_id_list)
+
+    # poll until every document reports progress == 1 (fully parsed);
+    # give up after 60 seconds (an arbitrary bound) instead of hanging forever
+    start_ts = timer()
+    while True:
+        res = get_docs_info(get_auth, doc_id_list)
+        finished_count = 0
+        for doc_info in res['data']:
+            if doc_info['progress'] == 1:
+                finished_count += 1
+        if finished_count == doc_count:
+            break
+        assert timer() - start_ts < 60, "document parsing timed out"
+        sleep(1)
+    print(f'time cost {timer() - start_ts:.1f}s')
+
+    # delete dataset
+    for dataset_id in dataset_list:
+        res = rm_dataset(get_auth, dataset_id)
+        assert res.get("code") == 0, f"{res.get('message')}"
+    print(f"{len(dataset_list)} datasets are deleted")