From 0a877941f406c52c1c4e5e4196da48a0d857ed5e Mon Sep 17 00:00:00 2001
From: liu an
Date: Thu, 13 Mar 2025 18:32:57 +0800
Subject: [PATCH] Test: Added test cases for Download Documents HTTP API (#6032)

### What problem does this PR solve?

Covers the [download document endpoint](https://ragflow.io/docs/dev/http_api_reference#download-document).

### Type of change

- [x] Add test cases
---
 sdk/python/test/libs/utils/__init__.py        |  12 ++
 sdk/python/test/test_http_api/common.py       |  44 ++++
 sdk/python/test/test_http_api/conftest.py     |  49 ++---
 .../test_download_document.py                 | 193 ++++++++++++++++++
 .../test_upload_documents.py                  |   7 +-
 5 files changed, 271 insertions(+), 34 deletions(-)
 create mode 100644 sdk/python/test/test_http_api/test_file_management_within_dataset/test_download_document.py

diff --git a/sdk/python/test/libs/utils/__init__.py b/sdk/python/test/libs/utils/__init__.py
index 9dd45daf3..86e9a8b6c 100644
--- a/sdk/python/test/libs/utils/__init__.py
+++ b/sdk/python/test/libs/utils/__init__.py
@@ -15,6 +15,7 @@
 #
 
 import base64
+import hashlib
 from pathlib import Path
 
 
@@ -23,3 +24,14 @@ def encode_avatar(image_path):
         binary_data = file.read()
     base64_encoded = base64.b64encode(binary_data).decode("utf-8")
     return base64_encoded
+
+
+def compare_by_hash(file1, file2, algorithm="sha256"):
+    def _calc_hash(file_path):
+        hash_func = hashlib.new(algorithm)
+        with open(file_path, "rb") as f:
+            while chunk := f.read(8192):
+                hash_func.update(chunk)
+        return hash_func.hexdigest()
+
+    return _calc_hash(file1) == _calc_hash(file2)
diff --git a/sdk/python/test/test_http_api/common.py b/sdk/python/test/test_http_api/common.py
index 08fc0d04b..2b87e4d95 100644
--- a/sdk/python/test/test_http_api/common.py
+++ b/sdk/python/test/test_http_api/common.py
@@ -18,6 +18,7 @@ import os
 from pathlib import Path
 
 import requests
+from libs.utils.file_utils import create_txt_file
 from requests_toolbelt import MultipartEncoder
 
 HEADERS = {"Content-Type": "application/json"}
@@ -99,3 +100,46 @@ def upload_documnets(auth, dataset_id, files_path=None):
         data=m,
     )
     return res.json()
+
+
+def batch_upload_documents(auth, dataset_id, num, tmp_path):
+    fps = []
+    for i in range(num):
+        fp = create_txt_file(tmp_path / f"ragflow_test_upload_{i}.txt")
+        fps.append(fp)
+    res = upload_documnets(auth, dataset_id, fps)
+    document_ids = []
+    for document in res["data"]:
+        document_ids.append(document["id"])
+    return document_ids
+
+
+def download_document(auth, dataset_id, document_id, save_path):
+    url = f"{HOST_ADDRESS}{FILE_API_URL}/{document_id}".format(dataset_id=dataset_id)
+    res = requests.get(url=url, auth=auth, stream=True)
+    try:
+        if res.status_code == 200:
+            with open(save_path, "wb") as f:
+                for chunk in res.iter_content(chunk_size=8192):
+                    f.write(chunk)
+    finally:
+        res.close()
+
+    return res
+
+
+def list_documnet(auth, dataset_id, params=None):
+    url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
+    res = requests.get(
+        url=url,
+        headers=HEADERS,
+        auth=auth,
+        params=params,
+    )
+    return res.json()
+
+
+def update_documnet(auth, dataset_id, document_id, payload):
+    url = f"{HOST_ADDRESS}{FILE_API_URL}/{document_id}".format(dataset_id=dataset_id)
+    res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
+    return res.json()
diff --git a/sdk/python/test/test_http_api/conftest.py b/sdk/python/test/test_http_api/conftest.py
index 9c3e9c36e..dfd6d592e 100644
--- a/sdk/python/test/test_http_api/conftest.py
+++ b/sdk/python/test/test_http_api/conftest.py
@@ -38,36 +38,23 @@ def clear_datasets(get_http_api_auth):
 
 
 @pytest.fixture
-def generate_test_files(tmp_path):
+def generate_test_files(request, tmp_path):
+    file_creators = {
+        "docx": (tmp_path / "ragflow_test.docx", create_docx_file),
+        "excel": (tmp_path / "ragflow_test.xlsx", create_excel_file),
+        "ppt": (tmp_path / "ragflow_test.pptx", create_ppt_file),
+        "image": (tmp_path / "ragflow_test.png", create_image_file),
+        "pdf": (tmp_path / "ragflow_test.pdf", create_pdf_file),
+        "txt": (tmp_path / "ragflow_test.txt", create_txt_file),
+        "md": (tmp_path / "ragflow_test.md", create_md_file),
+        "json": (tmp_path / "ragflow_test.json", create_json_file),
+        "eml": (tmp_path / "ragflow_test.eml", create_eml_file),
+        "html": (tmp_path / "ragflow_test.html", create_html_file),
+    }
+
     files = {}
-    files["docx"] = tmp_path / "ragflow_test.docx"
-    create_docx_file(files["docx"])
-
-    files["excel"] = tmp_path / "ragflow_test.xlsx"
-    create_excel_file(files["excel"])
-
-    files["ppt"] = tmp_path / "ragflow_test.pptx"
-    create_ppt_file(files["ppt"])
-
-    files["image"] = tmp_path / "ragflow_test.png"
-    create_image_file(files["image"])
-
-    files["pdf"] = tmp_path / "ragflow_test.pdf"
-    create_pdf_file(files["pdf"])
-
-    files["txt"] = tmp_path / "ragflow_test.txt"
-    create_txt_file(files["txt"])
-
-    files["md"] = tmp_path / "ragflow_test.md"
-    create_md_file(files["md"])
-
-    files["json"] = tmp_path / "ragflow_test.json"
-    create_json_file(files["json"])
-
-    files["eml"] = tmp_path / "ragflow_test.eml"
-    create_eml_file(files["eml"])
-
-    files["html"] = tmp_path / "ragflow_test.html"
-    create_html_file(files["html"])
-
+    for file_type, (file_path, creator_func) in file_creators.items():
+        if request.param in ["", file_type]:
+            creator_func(file_path)
+            files[file_type] = file_path
     return files
diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_download_document.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_download_document.py
new file mode 100644
index 000000000..da7d6d192
--- /dev/null
+++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_download_document.py
@@ -0,0 +1,193 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+import json
+from concurrent.futures import ThreadPoolExecutor
+
+import pytest
+from common import (
+    INVALID_API_TOKEN,
+    batch_upload_documents,
+    create_datasets,
+    download_document,
+    upload_documnets,
+)
+from libs.auth import RAGFlowHttpApiAuth
+from libs.utils import compare_by_hash
+from requests import codes
+
+
+class TestAuthorization:
+    @pytest.mark.parametrize(
+        "auth, expected_code, expected_message",
+        [
+            (None, 0, "`Authorization` can't be empty"),
+            (
+                RAGFlowHttpApiAuth(INVALID_API_TOKEN),
+                109,
+                "Authentication error: API key is invalid!",
+            ),
+        ],
+    )
+    def test_invalid_auth(
+        self, get_http_api_auth, tmp_path, auth, expected_code, expected_message
+    ):
+        ids = create_datasets(get_http_api_auth, 1)
+        document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
+        res = download_document(
+            auth, ids[0], document_ids[0], tmp_path / "ragflow_tes.txt"
+        )
+        assert res.status_code == codes.ok
+        with (tmp_path / "ragflow_tes.txt").open("r") as f:
+            response_json = json.load(f)
+        assert response_json["code"] == expected_code
+        assert response_json["message"] == expected_message
+
+
+class TestDownloadDocument:
+    @pytest.mark.parametrize(
+        "generate_test_files",
+        [
+            "docx",
+            "excel",
+            "ppt",
+            "image",
+            "pdf",
+            "txt",
+            "md",
+            "json",
+            "eml",
+            "html",
+        ],
+        indirect=True,
+    )
+    def test_file_type_validation(
+        self, get_http_api_auth, generate_test_files, request
+    ):
+        ids = create_datasets(get_http_api_auth, 1)
+        fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
+        res = upload_documnets(get_http_api_auth, ids[0], [fp])
+        document_id = res["data"][0]["id"]
+
+        res = download_document(
+            get_http_api_auth,
+            ids[0],
+            document_id,
+            fp.with_stem("ragflow_test_download"),
+        )
+        assert res.status_code == codes.ok
+        assert compare_by_hash(
+            fp,
+            fp.with_stem("ragflow_test_download"),
+        )
+
+    @pytest.mark.parametrize(
+        "docment_id, expected_code, expected_message",
+        [
+            pytest.param("", 0, "", marks=pytest.mark.xfail(reason="issue#6031")),
+            (
+                "invalid_document_id",
+                102,
+                "The dataset not own the document invalid_document_id.",
+            ),
+        ],
+    )
+    def test_invalid_docment_id(
+        self, get_http_api_auth, tmp_path, docment_id, expected_code, expected_message
+    ):
+        ids = create_datasets(get_http_api_auth, 1)
+        res = download_document(
+            get_http_api_auth,
+            ids[0],
+            docment_id,
+            tmp_path / "ragflow_test_download_1.txt",
+        )
+        assert res.status_code == codes.ok
+        with (tmp_path / "ragflow_test_download_1.txt").open("r") as f:
+            response_json = json.load(f)
+        assert response_json["code"] == expected_code
+        assert response_json["message"] == expected_message
+
+    @pytest.mark.parametrize(
+        "dataset_id, expected_code, expected_message",
+        [
+            ("", 100, ""),
+            (
+                "invalid_dataset_id",
+                102,
+                "You do not own the dataset invalid_dataset_id.",
+            ),
+        ],
+    )
+    def test_invalid_dataset_id(
+        self, get_http_api_auth, tmp_path, dataset_id, expected_code, expected_message
+    ):
+        ids = create_datasets(get_http_api_auth, 1)
+        document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
+        res = download_document(
+            get_http_api_auth,
+            dataset_id,
+            document_ids[0],
+            tmp_path / "ragflow_test_download_1.txt",
+        )
+        assert res.status_code == codes.ok
+        with (tmp_path / "ragflow_test_download_1.txt").open("r") as f:
+            response_json = json.load(f)
+        assert response_json["code"] == expected_code
+        assert response_json["message"] == expected_message
+
+    def test_same_file_repeat(self, get_http_api_auth, tmp_path):
+        num = 5
+        ids = create_datasets(get_http_api_auth, 1)
+        document_ids = batch_upload_documents(get_http_api_auth, ids[0], 1, tmp_path)
+        for i in range(num):
+            res = download_document(
+                get_http_api_auth,
+                ids[0],
+                document_ids[0],
+                tmp_path / f"ragflow_test_download_{i}.txt",
+            )
+            assert res.status_code == codes.ok
+            assert compare_by_hash(
+                tmp_path / "ragflow_test_upload_0.txt",
+                tmp_path / f"ragflow_test_download_{i}.txt",
+            )
+
+    def test_concurrent_download(self, get_http_api_auth, tmp_path):
+        document_count = 20
+        ids = create_datasets(get_http_api_auth, 1)
+        document_ids = batch_upload_documents(
+            get_http_api_auth, ids[0], document_count, tmp_path
+        )
+
+        with ThreadPoolExecutor(max_workers=5) as executor:
+            futures = [
+                executor.submit(
+                    download_document,
+                    get_http_api_auth,
+                    ids[0],
+                    document_ids[i],
+                    tmp_path / f"ragflow_test_download_{i}.txt",
+                )
+                for i in range(document_count)
+            ]
+        responses = [f.result() for f in futures]
+        assert all(r.status_code == codes.ok for r in responses)
+        for i in range(document_count):
+            assert compare_by_hash(
+                tmp_path / f"ragflow_test_upload_{i}.txt",
+                tmp_path / f"ragflow_test_download_{i}.txt",
+            )
diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_upload_documents.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_upload_documents.py
index b4cc86f91..2f55532ae 100644
--- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_upload_documents.py
+++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_upload_documents.py
@@ -64,7 +64,7 @@ class TestUploadDocuments:
         assert res["data"][0]["name"] == fp.name
 
     @pytest.mark.parametrize(
-        "file_type",
+        "generate_test_files",
         [
             "docx",
             "excel",
@@ -77,12 +77,13 @@ class TestUploadDocuments:
             "eml",
             "html",
         ],
+        indirect=True,
     )
     def test_file_type_validation(
-        self, get_http_api_auth, generate_test_files, file_type
+        self, get_http_api_auth, generate_test_files, request
     ):
         ids = create_datasets(get_http_api_auth, 1)
-        fp = generate_test_files[file_type]
+        fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
         res = upload_documnets(get_http_api_auth, ids[0], [fp])
         assert res["code"] == 0
         assert res["data"][0]["dataset_id"] == ids[0]
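
For reference, a minimal sketch (not part of the patch) of how the new `common.py` helpers compose when driven by hand instead of through pytest. All function names come from the diff above; the API key and the temporary directory are illustrative placeholders, and `create_datasets` is assumed, as in the tests, to return a list of dataset IDs.

```python
# Sketch only: exercise the new download helper manually, assuming a running
# RAGFlow HTTP API and a valid API key (placeholder below).
from pathlib import Path

from common import batch_upload_documents, create_datasets, download_document
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import compare_by_hash

auth = RAGFlowHttpApiAuth("<your-api-key>")     # placeholder key, not a real token
tmp_path = Path("/tmp/ragflow_download_demo")   # stands in for pytest's tmp_path fixture
tmp_path.mkdir(parents=True, exist_ok=True)

dataset_id = create_datasets(auth, 1)[0]
document_id = batch_upload_documents(auth, dataset_id, 1, tmp_path)[0]

# Download the document we just uploaded and verify the round trip is lossless.
res = download_document(auth, dataset_id, document_id, tmp_path / "downloaded.txt")
assert res.status_code == 200
# batch_upload_documents wrote ragflow_test_upload_0.txt, so the hashes should match.
assert compare_by_hash(tmp_path / "ragflow_test_upload_0.txt", tmp_path / "downloaded.txt")
```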