diff --git a/api/apps/dataset_api.py b/api/apps/dataset_api.py
index 884207e28..3b290630b 100644
--- a/api/apps/dataset_api.py
+++ b/api/apps/dataset_api.py
@@ -12,15 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import os
 import pathlib
 import re
 import warnings
+from io import BytesIO
 
-from flask import request
+from flask import request, send_file
 from flask_login import login_required, current_user
 from httpx import HTTPError
+from minio import S3Error
 
 from api.contants import NAME_LENGTH_LIMIT
 from api.db import FileType, ParserType, FileSource
@@ -283,9 +284,12 @@ def upload_documents(dataset_id):
         return construct_json_result(code=RetCode.DATA_ERROR,
                                      message=f"You try to upload {num_file_objs} files, "
                                              f"which exceeds the maximum number of uploading files: {MAXIMUM_OF_UPLOADING_FILES}")
+    # no dataset
+    exist, dataset = KnowledgebaseService.get_by_id(dataset_id)
+    if not exist:
+        return construct_json_result(message="Can't find this dataset", code=RetCode.DATA_ERROR)
+
     for file_obj in file_objs:
-        # the content of the file
-        file_content = file_obj.read()
         file_name = file_obj.filename
         # no name
         if not file_name:
@@ -296,15 +300,6 @@ def upload_documents(dataset_id):
         if 'http' in file_name:
             return construct_json_result(code=RetCode.ARGUMENT_ERROR, message="Remote files have not unsupported.")
 
-        # the content is empty, raising a warning
-        if file_content == b'':
-            warnings.warn(f"[WARNING]: The file {file_name} is empty.")
-
-        # no dataset
-        exist, dataset = KnowledgebaseService.get_by_id(dataset_id)
-        if not exist:
-            return construct_json_result(message="Can't find this dataset", code=RetCode.DATA_ERROR)
-
         # get the root_folder
         root_folder = FileService.get_root_folder(current_user.id)
         # get the id of the root_folder
@@ -342,8 +337,14 @@ def upload_documents(dataset_id):
             location = filename
             while MINIO.obj_exist(dataset_id, location):
                 location += "_"
 
+            blob = file.read()
+            # the content is empty, raising a warning
+            if blob == b'':
+                warnings.warn(f"[WARNING]: The file {filename} is empty.")
+
             MINIO.put(dataset_id, location, blob)
+
             doc = {
                 "id": get_uuid(),
                 "kb_id": dataset.id,
@@ -555,6 +556,40 @@ def is_illegal_value_for_enum(value, enum_class):
     return value not in enum_class.__members__.values()
 
 # ----------------------------download a file-----------------------------------------------------
+@manager.route("/<dataset_id>/documents/<document_id>", methods=["GET"])
+@login_required
+def download_document(dataset_id, document_id):
+    try:
+        # Check whether there is this dataset
+        exist, _ = KnowledgebaseService.get_by_id(dataset_id)
+        if not exist:
+            return construct_json_result(code=RetCode.DATA_ERROR, message=f"This dataset '{dataset_id}' cannot be found!")
+
+        # Check whether there is this document
+        exist, document = DocumentService.get_by_id(document_id)
+        if not exist:
+            return construct_json_result(message=f"This document '{document_id}' cannot be found!",
+                                         code=RetCode.ARGUMENT_ERROR)
+
+        # The process of downloading
+        doc_id, doc_location = File2DocumentService.get_minio_address(doc_id=document_id)  # minio address
+        file_stream = MINIO.get(doc_id, doc_location)
+        if not file_stream:
+            return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR)
+
+        file = BytesIO(file_stream)
+
+        # Use send_file with a proper filename and MIME type
+        return send_file(
+            file,
+            as_attachment=True,
+            download_name=document.name,
+            mimetype='application/octet-stream'  # Set a default MIME type
+        )
+
+    # Error
+    except Exception as e:
+        return construct_error_response(e)
 
 
 # ----------------------------start parsing-----------------------------------------------------
@@ -564,7 +599,7 @@ def is_illegal_value_for_enum(value, enum_class):
 # ----------------------------list the chunks of the file-----------------------------------------------------
 
 
-# ----------------------------delete the chunk-----------------------------------------------------
+# -- --------------------------delete the chunk-----------------------------------------------------
 
 
 # ----------------------------edit the status of the chunk-----------------------------------------------------
@@ -576,3 +611,5 @@ def is_illegal_value_for_enum(value, enum_class):
 
 
 # ----------------------------retrieval test-----------------------------------------------------
+
+
diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py
index ee940e835..6275f921c 100644
--- a/sdk/python/ragflow/ragflow.py
+++ b/sdk/python/ragflow/ragflow.py
@@ -12,12 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
 import os
 
 import requests
 
+from api.db.services.document_service import DocumentService
 from api.settings import RetCode
 
 
@@ -126,7 +126,22 @@ class RAGFlow:
         return response.json()
 
     # ----------------------------download a file-----------------------------------------------------
+    def download_file(self, dataset_id, document_id):
+        endpoint = f"{self.dataset_url}/{dataset_id}/documents/{document_id}"
+        res = requests.get(endpoint, headers=self.authorization_header)
+        content = res.content  # binary data
+        # decode the binary data
+        try:
+            decoded_content = content.decode("utf-8")
+            json_data = json.loads(decoded_content)
+            return json_data  # message
+        except json.JSONDecodeError:  # binary data
+            _, document = DocumentService.get_by_id(document_id)
+            file_path = os.path.join(os.getcwd(), document.name)
+            with open(file_path, "wb") as file:
+                file.write(content)
+            return {"code": RetCode.SUCCESS, "data": content}
 
     # ----------------------------start parsing-----------------------------------------------------
 
     # ----------------------------stop parsing-----------------------------------------------------
@@ -144,3 +159,4 @@
     # ----------------------------get a specific chunk-----------------------------------------------------
 
     # ----------------------------retrieval test-----------------------------------------------------
+
diff --git a/sdk/python/test/common.py b/sdk/python/test/common.py
index 5dd313f50..94acbf48c 100644
--- a/sdk/python/test/common.py
+++ b/sdk/python/test/common.py
@@ -1,4 +1,4 @@
-API_KEY = 'ImFhMmJhZmUwMmQxNzExZWZhZDdmMzA0M2Q3ZWU1MzdlIg.ZnDsIQ.u-0-_qCRU6a4WICxyAPsjaafyOo'
+API_KEY = 'IjJkOGQ4ZDE2MzkyMjExZWZhYTk0MzA0M2Q3ZWU1MzdlIg.ZoUfug.RmqcYyCrlAnLtkzk6bYXiXN3eEY'
 
 
 HOST_ADDRESS = 'http://127.0.0.1:9380'
\ No newline at end of file
diff --git a/sdk/python/test/test_document.py b/sdk/python/test/test_document.py
index 81b84692f..f7f87a148 100644
--- a/sdk/python/test/test_document.py
+++ b/sdk/python/test/test_document.py
@@ -3,7 +3,6 @@ from test_sdkbase import TestSdk
 from ragflow import RAGFlow
 import pytest
 from common import API_KEY, HOST_ADDRESS
-from api.contants import NAME_LENGTH_LIMIT
 
 
 class TestFile(TestSdk):
@@ -625,8 +624,76 @@ class TestFile(TestSdk):
         update_res = ragflow.update_file(created_res_id, doc_id, **params)
         assert (update_res["code"] == RetCode.DATA_ERROR
                 and update_res["message"] == "Illegal value ? for 'template_type' field.")
 
+    # ----------------------------download a file-----------------------------------------------------
+    def test_download_nonexistent_document(self):
+        """
+        Test downloading a document which does not exist.
+        """
+        # create a dataset
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        created_res = ragflow.create_dataset("test_download_nonexistent_document")
+        created_res_id = created_res["data"]["dataset_id"]
+        res = ragflow.download_file(created_res_id, "imagination")
+        assert res["code"] == RetCode.ARGUMENT_ERROR and res["message"] == f"This document 'imagination' cannot be found!"
+
+    def test_download_document_in_nonexistent_dataset(self):
+        """
+        Test downloading a document whose dataset is nonexistent.
+        """
+        # create a dataset
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        created_res = ragflow.create_dataset("test_download_nonexistent_document")
+        created_res_id = created_res["data"]["dataset_id"]
+        # upload files
+        file_paths = ["test_data/test.txt"]
+        uploading_res = ragflow.upload_local_file(created_res_id, file_paths)
+        # get the doc_id
+        data = uploading_res["data"][0]
+        doc_id = data["id"]
+        # download file
+        res = ragflow.download_file("imagination", doc_id)
+        assert res["code"] == RetCode.DATA_ERROR and res["message"] == f"This dataset 'imagination' cannot be found!"
+
+    def test_download_document_with_success(self):
+        """
+        Test the downloading of a document with success.
+        """
+        # create a dataset
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        created_res = ragflow.create_dataset("test_download_nonexistent_document")
+        created_res_id = created_res["data"]["dataset_id"]
+        # upload files
+        file_paths = ["test_data/test.txt"]
+        uploading_res = ragflow.upload_local_file(created_res_id, file_paths)
+        # get the doc_id
+        data = uploading_res["data"][0]
+        doc_id = data["id"]
+        # download file
+        with open("test_data/test.txt", "rb") as file:
+            binary_data = file.read()
+        res = ragflow.download_file(created_res_id, doc_id)
+        assert res["code"] == RetCode.SUCCESS and res["data"] == binary_data
+
+    def test_download_an_empty_document(self):
+        """
+        Test the downloading of an empty document.
+        """
+        # create a dataset
+        ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
+        created_res = ragflow.create_dataset("test_download_nonexistent_document")
+        created_res_id = created_res["data"]["dataset_id"]
+        # upload files
+        file_paths = ["test_data/empty.txt"]
+        uploading_res = ragflow.upload_local_file(created_res_id, file_paths)
+        # get the doc_id
+        data = uploading_res["data"][0]
+        doc_id = data["id"]
+        # download file
+        res = ragflow.download_file(created_res_id, doc_id)
+        assert res["code"] == RetCode.DATA_ERROR and res["message"] == "This file is empty."
+
     # ----------------------------start parsing-----------------------------------------------------
 
 
     # ----------------------------stop parsing-----------------------------------------------------
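
For quick reference, a minimal usage sketch of the SDK method this patch adds (not part of the diff; the dataset and document IDs below are hypothetical placeholders, and API_KEY / HOST_ADDRESS are the values defined in sdk/python/test/common.py):

    from common import API_KEY, HOST_ADDRESS
    from ragflow import RAGFlow
    from api.settings import RetCode

    ragflow = RAGFlow(API_KEY, HOST_ADDRESS)
    # "my_dataset_id" and "my_document_id" are placeholders, not real IDs.
    res = ragflow.download_file("my_dataset_id", "my_document_id")

    if res["code"] == RetCode.SUCCESS:
        # On success the SDK has also written the file into the current working
        # directory under the document's original name; res["data"] holds the raw bytes.
        print(f"Downloaded {len(res['data'])} bytes.")
    else:
        # Otherwise the server's JSON payload (e.g. "This file is empty." or
        # "This dataset '...' cannot be found!") is returned unchanged.
        print(res["message"])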