From 0ebf05440ee3d6353544fe165b6e6533532bb06d Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Mon, 19 May 2025 14:54:06 +0800 Subject: [PATCH] Feat: repair corrupted PDF files on upload automatically (#7693) ### What problem does this PR solve? Try the best to repair corrupted PDF files on upload automatically. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- .github/workflows/tests.yml | 4 +- Dockerfile | 3 +- api/apps/document_app.py | 323 ++++++++++++-------------------- api/db/services/file_service.py | 143 ++++++-------- api/utils/file_utils.py | 101 +++++++--- 5 files changed, 251 insertions(+), 323 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0ec4766e0..aad6b6fc9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -50,9 +50,9 @@ jobs: # https://github.com/astral-sh/ruff-action - name: Static check with Ruff - uses: astral-sh/ruff-action@v2 + uses: astral-sh/ruff-action@v3 with: - version: ">=0.8.2" + version: ">=0.11.x" args: "check" - name: Build ragflow:nightly-slim diff --git a/Dockerfile b/Dockerfile index 47c533b9d..67fd26456 100644 --- a/Dockerfile +++ b/Dockerfile @@ -59,7 +59,8 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \ apt install -y libatk-bridge2.0-0 && \ apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \ apt install -y libjemalloc-dev && \ - apt install -y python3-pip pipx nginx unzip curl wget git vim less + apt install -y python3-pip pipx nginx unzip curl wget git vim less && \ + apt install -y ghostscript RUN if [ "$NEED_MIRROR" == "1" ]; then \ pip3 config set global.index-url https://mirrors.aliyun.com/pypi/simple && \ diff --git a/api/apps/document_app.py b/api/apps/document_app.py index c37c9d1ee..43c7f813b 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -20,79 +20,73 @@ import re import flask from flask import request -from flask_login import login_required, current_user +from flask_login import current_user, login_required -from deepdoc.parser.html_parser import RAGFlowHtmlParser -from rag.nlp import search - -from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileType, TaskStatus, ParserType, FileSource +from api import settings +from api.constants import IMG_BASE64_PREFIX +from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileSource, FileType, ParserType, TaskStatus from api.db.db_models import File, Task +from api.db.services import duplicate_name +from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService -from api.db.services.task_service import queue_tasks -from api.db.services.user_service import UserTenantService -from api.db.services import duplicate_name from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.task_service import TaskService -from api.db.services.document_service import DocumentService, doc_upload_and_parse +from api.db.services.task_service import TaskService, queue_tasks +from api.db.services.user_service import UserTenantService +from api.utils import get_uuid from api.utils.api_utils import ( - server_error_response, get_data_error_result, + get_json_result, + server_error_response, validate_request, ) -from api.utils import get_uuid -from api import settings -from api.utils.api_utils import get_json_result -from rag.utils.storage_factory import STORAGE_IMPL -from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory +from api.utils.file_utils import filename_type, get_project_base_directory, thumbnail from api.utils.web_utils import html2pdf, is_valid_url -from api.constants import IMG_BASE64_PREFIX +from deepdoc.parser.html_parser import RAGFlowHtmlParser +from rag.nlp import search +from rag.utils.storage_factory import STORAGE_IMPL -@manager.route('/upload', methods=['POST']) # noqa: F821 +@manager.route("/upload", methods=["POST"]) # noqa: F821 @login_required @validate_request("kb_id") def upload(): kb_id = request.form.get("kb_id") if not kb_id: - return get_json_result( - data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) - if 'file' not in request.files: - return get_json_result( - data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) + if "file" not in request.files: + return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist('file') + file_objs = request.files.getlist("file") for file_obj in file_objs: - if file_obj.filename == '': - return get_json_result( - data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) + if file_obj.filename == "": + return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: raise LookupError("Can't find this knowledgebase!") - err, files = FileService.upload_document(kb, file_objs, current_user.id) - files = [f[0] for f in files] # remove the blob - + + if not files: + return get_json_result(data=files, message="There seems to be an issue with your file format. Please verify it is correct and not corrupted.", code=settings.RetCode.DATA_ERROR) + files = [f[0] for f in files] # remove the blob + if err: - return get_json_result( - data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR) + return get_json_result(data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR) return get_json_result(data=files) -@manager.route('/web_crawl', methods=['POST']) # noqa: F821 +@manager.route("/web_crawl", methods=["POST"]) # noqa: F821 @login_required @validate_request("kb_id", "name", "url") def web_crawl(): kb_id = request.form.get("kb_id") if not kb_id: - return get_json_result( - data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) name = request.form.get("name") url = request.form.get("url") if not is_valid_url(url): - return get_json_result( - data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: raise LookupError("Can't find this knowledgebase!") @@ -108,10 +102,7 @@ def web_crawl(): kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) try: - filename = duplicate_name( - DocumentService.query, - name=name + ".pdf", - kb_id=kb.id) + filename = duplicate_name(DocumentService.query, name=name + ".pdf", kb_id=kb.id) filetype = filename_type(filename) if filetype == FileType.OTHER.value: raise RuntimeError("This type of file has not been supported yet!") @@ -130,7 +121,7 @@ def web_crawl(): "name": filename, "location": location, "size": len(blob), - "thumbnail": thumbnail(filename, blob) + "thumbnail": thumbnail(filename, blob), } if doc["type"] == FileType.VISUAL: doc["parser_id"] = ParserType.PICTURE.value @@ -147,58 +138,53 @@ def web_crawl(): return get_json_result(data=True) -@manager.route('/create', methods=['POST']) # noqa: F821 +@manager.route("/create", methods=["POST"]) # noqa: F821 @login_required @validate_request("name", "kb_id") def create(): req = request.json kb_id = req["kb_id"] if not kb_id: - return get_json_result( - data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) try: e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: - return get_data_error_result( - message="Can't find this knowledgebase!") + return get_data_error_result(message="Can't find this knowledgebase!") if DocumentService.query(name=req["name"], kb_id=kb_id): - return get_data_error_result( - message="Duplicated document name in the same knowledgebase.") + return get_data_error_result(message="Duplicated document name in the same knowledgebase.") - doc = DocumentService.insert({ - "id": get_uuid(), - "kb_id": kb.id, - "parser_id": kb.parser_id, - "parser_config": kb.parser_config, - "created_by": current_user.id, - "type": FileType.VIRTUAL, - "name": req["name"], - "location": "", - "size": 0 - }) + doc = DocumentService.insert( + { + "id": get_uuid(), + "kb_id": kb.id, + "parser_id": kb.parser_id, + "parser_config": kb.parser_config, + "created_by": current_user.id, + "type": FileType.VIRTUAL, + "name": req["name"], + "location": "", + "size": 0, + } + ) return get_json_result(data=doc.to_json()) except Exception as e: return server_error_response(e) -@manager.route('/list', methods=['POST']) # noqa: F821 +@manager.route("/list", methods=["POST"]) # noqa: F821 @login_required def list_docs(): kb_id = request.args.get("kb_id") if not kb_id: - return get_json_result( - data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) tenants = UserTenantService.query(user_id=current_user.id) for tenant in tenants: - if KnowledgebaseService.query( - tenant_id=tenant.tenant_id, id=kb_id): + if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): break else: - return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', - code=settings.RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR) keywords = request.args.get("keywords", "") page_number = int(request.args.get("page", 0)) @@ -212,83 +198,67 @@ def list_docs(): if run_status: invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS} if invalid_status: - return get_data_error_result( - message=f"Invalid filter run status conditions: {', '.join(invalid_status)}" - ) + return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}") types = req.get("types", []) if types: invalid_types = {t for t in types if t not in VALID_FILE_TYPES} if invalid_types: - return get_data_error_result( - message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}" - ) + return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}") try: - docs, tol = DocumentService.get_by_kb_id( - kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types) + docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types) for doc_item in docs: - if doc_item['thumbnail'] and not doc_item['thumbnail'].startswith(IMG_BASE64_PREFIX): - doc_item['thumbnail'] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}" + if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX): + doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}" return get_json_result(data={"total": tol, "docs": docs}) except Exception as e: return server_error_response(e) -@manager.route('/infos', methods=['POST']) # noqa: F821 +@manager.route("/infos", methods=["POST"]) # noqa: F821 @login_required def docinfos(): req = request.json doc_ids = req["doc_ids"] for doc_id in doc_ids: if not DocumentService.accessible(doc_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=settings.RetCode.AUTHENTICATION_ERROR - ) + return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) docs = DocumentService.get_by_ids(doc_ids) return get_json_result(data=list(docs.dicts())) -@manager.route('/thumbnails', methods=['GET']) # noqa: F821 +@manager.route("/thumbnails", methods=["GET"]) # noqa: F821 # @login_required def thumbnails(): doc_ids = request.args.get("doc_ids").split(",") if not doc_ids: - return get_json_result( - data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR) try: docs = DocumentService.get_thumbnails(doc_ids) for doc_item in docs: - if doc_item['thumbnail'] and not doc_item['thumbnail'].startswith(IMG_BASE64_PREFIX): - doc_item['thumbnail'] = f"/v1/document/image/{doc_item['kb_id']}-{doc_item['thumbnail']}" + if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX): + doc_item["thumbnail"] = f"/v1/document/image/{doc_item['kb_id']}-{doc_item['thumbnail']}" return get_json_result(data={d["id"]: d["thumbnail"] for d in docs}) except Exception as e: return server_error_response(e) -@manager.route('/change_status', methods=['POST']) # noqa: F821 +@manager.route("/change_status", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id", "status") def change_status(): req = request.json if str(req["status"]) not in ["0", "1"]: - return get_json_result( - data=False, - message='"Status" must be either 0 or 1!', - code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=settings.RetCode.ARGUMENT_ERROR) if not DocumentService.accessible(req["doc_id"], current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=settings.RetCode.AUTHENTICATION_ERROR) + return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) try: e, doc = DocumentService.get_by_id(req["doc_id"]) @@ -296,23 +266,19 @@ def change_status(): return get_data_error_result(message="Document not found!") e, kb = KnowledgebaseService.get_by_id(doc.kb_id) if not e: - return get_data_error_result( - message="Can't find this knowledgebase!") + return get_data_error_result(message="Can't find this knowledgebase!") - if not DocumentService.update_by_id( - req["doc_id"], {"status": str(req["status"])}): - return get_data_error_result( - message="Database error (Document update)!") + if not DocumentService.update_by_id(req["doc_id"], {"status": str(req["status"])}): + return get_data_error_result(message="Database error (Document update)!") status = int(req["status"]) - settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, - search.index_name(kb.tenant_id), doc.kb_id) + settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id) return get_json_result(data=True) except Exception as e: return server_error_response(e) -@manager.route('/rm', methods=['POST']) # noqa: F821 +@manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id") def rm(): @@ -323,11 +289,7 @@ def rm(): for doc_id in doc_ids: if not DocumentService.accessible4deletion(doc_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=settings.RetCode.AUTHENTICATION_ERROR - ) + return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) root_folder = FileService.get_root_folder(current_user.id) pf_id = root_folder["id"] @@ -347,8 +309,7 @@ def rm(): TaskService.filter_delete([Task.doc_id == doc_id]) if not DocumentService.remove_document(doc, tenant_id): - return get_data_error_result( - message="Database error (Document removal)!") + return get_data_error_result(message="Database error (Document removal)!") f2d = File2DocumentService.get_by_document_id(doc_id) deleted_file_count = 0 @@ -376,18 +337,14 @@ def rm(): return get_json_result(data=True) -@manager.route('/run', methods=['POST']) # noqa: F821 +@manager.route("/run", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_ids", "run") -def run(): +def run(): req = request.json for doc_id in req["doc_ids"]: if not DocumentService.accessible(doc_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=settings.RetCode.AUTHENTICATION_ERROR - ) + return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) try: kb_table_num_map = {} for id in req["doc_ids"]: @@ -421,7 +378,7 @@ def run(): if kb_id not in kb_table_num_map: count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[]) kb_table_num_map[kb_id] = count - if kb_table_num_map[kb_id] <=0: + if kb_table_num_map[kb_id] <= 0: KnowledgebaseService.delete_field_map(kb_id) bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"]) queue_tasks(doc, bucket, name, 0) @@ -431,36 +388,25 @@ def run(): return server_error_response(e) -@manager.route('/rename', methods=['POST']) # noqa: F821 +@manager.route("/rename", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id", "name") def rename(): req = request.json if not DocumentService.accessible(req["doc_id"], current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=settings.RetCode.AUTHENTICATION_ERROR - ) + return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(message="Document not found!") - if pathlib.Path(req["name"].lower()).suffix != pathlib.Path( - doc.name.lower()).suffix: - return get_json_result( - data=False, - message="The extension of file can't be changed", - code=settings.RetCode.ARGUMENT_ERROR) + if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix: + return get_json_result(data=False, message="The extension of file can't be changed", code=settings.RetCode.ARGUMENT_ERROR) for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): if d.name == req["name"]: - return get_data_error_result( - message="Duplicated document name in the same knowledgebase.") + return get_data_error_result(message="Duplicated document name in the same knowledgebase.") - if not DocumentService.update_by_id( - req["doc_id"], {"name": req["name"]}): - return get_data_error_result( - message="Database error (Document rename)!") + if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}): + return get_data_error_result(message="Database error (Document rename)!") informs = File2DocumentService.get_by_document_id(req["doc_id"]) if informs: @@ -472,7 +418,7 @@ def rename(): return server_error_response(e) -@manager.route('/get/', methods=['GET']) # noqa: F821 +@manager.route("/get/", methods=["GET"]) # noqa: F821 # @login_required def get(doc_id): try: @@ -486,29 +432,22 @@ def get(doc_id): ext = re.search(r"\.([^.]+)$", doc.name) if ext: if doc.type == FileType.VISUAL.value: - response.headers.set('Content-Type', 'image/%s' % ext.group(1)) + response.headers.set("Content-Type", "image/%s" % ext.group(1)) else: - response.headers.set( - 'Content-Type', - 'application/%s' % - ext.group(1)) + response.headers.set("Content-Type", "application/%s" % ext.group(1)) return response except Exception as e: return server_error_response(e) -@manager.route('/change_parser', methods=['POST']) # noqa: F821 +@manager.route("/change_parser", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id", "parser_id") def change_parser(): req = request.json if not DocumentService.accessible(req["doc_id"], current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=settings.RetCode.AUTHENTICATION_ERROR - ) + return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: @@ -520,21 +459,16 @@ def change_parser(): else: return get_json_result(data=True) - if ((doc.type == FileType.VISUAL and req["parser_id"] != "picture") - or (re.search( - r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation")): + if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"): return get_data_error_result(message="Not supported yet!") - e = DocumentService.update_by_id(doc.id, - {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", - "run": TaskStatus.UNSTART.value}) + e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value}) if not e: return get_data_error_result(message="Document not found!") if "parser_config" in req: DocumentService.update_parser_config(doc.id, req["parser_config"]) if doc.token_num > 0: - e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, - doc.process_duation * -1) + e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duation * -1) if not e: return get_data_error_result(message="Document not found!") tenant_id = DocumentService.get_tenant_id(req["doc_id"]) @@ -548,7 +482,7 @@ def change_parser(): return server_error_response(e) -@manager.route('/image/', methods=['GET']) # noqa: F821 +@manager.route("/image/", methods=["GET"]) # noqa: F821 # @login_required def get_image(image_id): try: @@ -557,53 +491,46 @@ def get_image(image_id): return get_data_error_result(message="Image not found.") bkt, nm = image_id.split("-") response = flask.make_response(STORAGE_IMPL.get(bkt, nm)) - response.headers.set('Content-Type', 'image/JPEG') + response.headers.set("Content-Type", "image/JPEG") return response except Exception as e: return server_error_response(e) -@manager.route('/upload_and_parse', methods=['POST']) # noqa: F821 +@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821 @login_required @validate_request("conversation_id") def upload_and_parse(): - if 'file' not in request.files: - return get_json_result( - data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) + if "file" not in request.files: + return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist('file') + file_objs = request.files.getlist("file") for file_obj in file_objs: - if file_obj.filename == '': - return get_json_result( - data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) + if file_obj.filename == "": + return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR) doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id) return get_json_result(data=doc_ids) -@manager.route('/parse', methods=['POST']) # noqa: F821 +@manager.route("/parse", methods=["POST"]) # noqa: F821 @login_required def parse(): url = request.json.get("url") if request.json else "" if url: if not is_valid_url(url): - return get_json_result( - data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR) download_path = os.path.join(get_project_base_directory(), "logs/downloads") os.makedirs(download_path, exist_ok=True) from seleniumwire.webdriver import Chrome, ChromeOptions + options = ChromeOptions() - options.add_argument('--headless') - options.add_argument('--disable-gpu') - options.add_argument('--no-sandbox') - options.add_argument('--disable-dev-shm-usage') - options.add_experimental_option('prefs', { - 'download.default_directory': download_path, - 'download.prompt_for_download': False, - 'download.directory_upgrade': True, - 'safebrowsing.enabled': True - }) + options.add_argument("--headless") + options.add_argument("--disable-gpu") + options.add_argument("--no-sandbox") + options.add_argument("--disable-dev-shm-usage") + options.add_experimental_option("prefs", {"download.default_directory": download_path, "download.prompt_for_download": False, "download.directory_upgrade": True, "safebrowsing.enabled": True}) driver = Chrome(options=options) driver.get(url) res_headers = [r.response.headers for r in driver.requests if r and r.response] @@ -626,51 +553,41 @@ def parse(): r = re.search(r"filename=\"([^\"]+)\"", str(res_headers)) if not r or not r.group(1): - return get_json_result( - data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR) f = File(r.group(1), os.path.join(download_path, r.group(1))) txt = FileService.parse_docs([f], current_user.id) return get_json_result(data=txt) - if 'file' not in request.files: - return get_json_result( - data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) + if "file" not in request.files: + return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist('file') + file_objs = request.files.getlist("file") txt = FileService.parse_docs(file_objs, current_user.id) return get_json_result(data=txt) -@manager.route('/set_meta', methods=['POST']) # noqa: F821 +@manager.route("/set_meta", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id", "meta") def set_meta(): req = request.json if not DocumentService.accessible(req["doc_id"], current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=settings.RetCode.AUTHENTICATION_ERROR - ) + return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) try: meta = json.loads(req["meta"]) except Exception as e: - return get_json_result( - data=False, message=f'Json syntax error: {e}', code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR) if not isinstance(meta, dict): - return get_json_result( - data=False, message='Meta data should be in Json map format, like {"key": "value"}', code=settings.RetCode.ARGUMENT_ERROR) + return get_json_result(data=False, message='Meta data should be in Json map format, like {"key": "value"}', code=settings.RetCode.ARGUMENT_ERROR) try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(message="Document not found!") - if not DocumentService.update_by_id( - req["doc_id"], {"meta_fields": meta}): - return get_data_error_result( - message="Database error (meta updates)!") + if not DocumentService.update_by_id(req["doc_id"], {"meta_fields": meta}): + return get_data_error_result(message="Database error (meta updates)!") return get_json_result(data=True) except Exception as e: diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 60b51a0e4..803164f9e 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -14,22 +14,21 @@ # limitations under the License. # import logging -import re import os +import re from concurrent.futures import ThreadPoolExecutor from flask_login import current_user from peewee import fn -from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource, ParserType -from api.db.db_models import DB, File2Document, Knowledgebase -from api.db.db_models import File, Document +from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType +from api.db.db_models import DB, Document, File, File2Document, Knowledgebase from api.db.services import duplicate_name from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.utils import get_uuid -from api.utils.file_utils import filename_type, thumbnail_img +from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img from rag.utils.storage_factory import STORAGE_IMPL @@ -39,8 +38,7 @@ class FileService(CommonService): @classmethod @DB.connection_context() - def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, - orderby, desc, keywords): + def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords): # Get files by parent folder ID with pagination and filtering # Args: # tenant_id: ID of the tenant @@ -53,17 +51,9 @@ class FileService(CommonService): # Returns: # Tuple of (file_list, total_count) if keywords: - files = cls.model.select().where( - (cls.model.tenant_id == tenant_id), - (cls.model.parent_id == pf_id), - (fn.LOWER(cls.model.name).contains(keywords.lower())), - ~(cls.model.id == pf_id) - ) + files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), (fn.LOWER(cls.model.name).contains(keywords.lower())), ~(cls.model.id == pf_id)) else: - files = cls.model.select().where((cls.model.tenant_id == tenant_id), - (cls.model.parent_id == pf_id), - ~(cls.model.id == pf_id) - ) + files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), ~(cls.model.id == pf_id)) count = files.count() if desc: files = files.order_by(cls.model.getter_by(orderby).desc()) @@ -76,16 +66,20 @@ class FileService(CommonService): for file in res_files: if file["type"] == FileType.FOLDER.value: file["size"] = cls.get_folder_size(file["id"]) - file['kbs_info'] = [] - children = list(cls.model.select().where( - (cls.model.tenant_id == tenant_id), - (cls.model.parent_id == file["id"]), - ~(cls.model.id == file["id"]), - ).dicts()) - file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children) + file["kbs_info"] = [] + children = list( + cls.model.select() + .where( + (cls.model.tenant_id == tenant_id), + (cls.model.parent_id == file["id"]), + ~(cls.model.id == file["id"]), + ) + .dicts() + ) + file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children) continue - kbs_info = cls.get_kb_id_by_file_id(file['id']) - file['kbs_info'] = kbs_info + kbs_info = cls.get_kb_id_by_file_id(file["id"]) + file["kbs_info"] = kbs_info return res_files, count @@ -97,16 +91,18 @@ class FileService(CommonService): # file_id: File ID # Returns: # List of dictionaries containing knowledge base IDs and names - kbs = (cls.model.select(*[Knowledgebase.id, Knowledgebase.name]) - .join(File2Document, on=(File2Document.file_id == file_id)) - .join(Document, on=(File2Document.document_id == Document.id)) - .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id)) - .where(cls.model.id == file_id)) + kbs = ( + cls.model.select(*[Knowledgebase.id, Knowledgebase.name]) + .join(File2Document, on=(File2Document.file_id == file_id)) + .join(Document, on=(File2Document.document_id == Document.id)) + .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id)) + .where(cls.model.id == file_id) + ) if not kbs: return [] kbs_info_list = [] for kb in list(kbs.dicts()): - kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']}) + kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]}) return kbs_info_list @classmethod @@ -178,16 +174,9 @@ class FileService(CommonService): if count > len(name) - 2: return file else: - file = cls.insert({ - "id": get_uuid(), - "parent_id": parent_id, - "tenant_id": current_user.id, - "created_by": current_user.id, - "name": name[count], - "location": "", - "size": 0, - "type": FileType.FOLDER.value - }) + file = cls.insert( + {"id": get_uuid(), "parent_id": parent_id, "tenant_id": current_user.id, "created_by": current_user.id, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value} + ) return cls.create_folder(file, file.id, name, count + 1) @classmethod @@ -212,9 +201,7 @@ class FileService(CommonService): # tenant_id: Tenant ID # Returns: # Root folder dictionary - for file in cls.model.select().where((cls.model.tenant_id == tenant_id), - (cls.model.parent_id == cls.model.id) - ): + for file in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)): return file.to_dict() file_id = get_uuid() @@ -239,11 +226,8 @@ class FileService(CommonService): # tenant_id: Tenant ID # Returns: # Knowledge base folder dictionary - for root in cls.model.select().where( - (cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)): - for folder in cls.model.select().where( - (cls.model.tenant_id == tenant_id), (cls.model.parent_id == root.id), - (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)): + for root in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)): + for folder in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == root.id), (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)): return folder.to_dict() assert False, "Can't find the KB folder. Database init error." @@ -271,7 +255,7 @@ class FileService(CommonService): "type": ty, "size": size, "location": location, - "source_type": FileSource.KNOWLEDGEBASE + "source_type": FileSource.KNOWLEDGEBASE, } cls.save(**file) return file @@ -283,12 +267,11 @@ class FileService(CommonService): # Args: # root_id: Root folder ID # tenant_id: Tenant ID - for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)\ - & (cls.model.parent_id == root_id)): + for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME) & (cls.model.parent_id == root_id)): return folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id) - for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id==tenant_id): + for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id == tenant_id): kb_folder = cls.new_a_file_from_kb(tenant_id, kb.name, folder["id"]) for doc in DocumentService.query(kb_id=kb.id): FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], tenant_id) @@ -357,12 +340,10 @@ class FileService(CommonService): @DB.connection_context() def delete_folder_by_pf_id(cls, user_id, folder_id): try: - files = cls.model.select().where((cls.model.tenant_id == user_id) - & (cls.model.parent_id == folder_id)) + files = cls.model.select().where((cls.model.tenant_id == user_id) & (cls.model.parent_id == folder_id)) for file in files: cls.delete_folder_by_pf_id(user_id, file.id) - return cls.model.delete().where((cls.model.tenant_id == user_id) - & (cls.model.id == folder_id)).execute(), + return (cls.model.delete().where((cls.model.tenant_id == user_id) & (cls.model.id == folder_id)).execute(),) except Exception: logging.exception("delete_folder_by_pf_id") raise RuntimeError("Database error (File retrieval)!") @@ -380,8 +361,7 @@ class FileService(CommonService): def dfs(parent_id): nonlocal size - for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where( - cls.model.parent_id == parent_id, cls.model.id != parent_id): + for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id, cls.model.id != parent_id): size += f.size if f.type == FileType.FOLDER.value: dfs(f.id) @@ -403,16 +383,16 @@ class FileService(CommonService): "type": doc["type"], "size": doc["size"], "location": doc["location"], - "source_type": FileSource.KNOWLEDGEBASE + "source_type": FileSource.KNOWLEDGEBASE, } cls.save(**file) File2DocumentService.save(**{"id": get_uuid(), "file_id": file["id"], "document_id": doc["id"]}) - + @classmethod @DB.connection_context() def move_file(cls, file_ids, folder_id): try: - cls.filter_update((cls.model.id << file_ids, ), { 'parent_id': folder_id }) + cls.filter_update((cls.model.id << file_ids,), {"parent_id": folder_id}) except Exception: logging.exception("move_file") raise RuntimeError("Database error (File move)!") @@ -429,16 +409,13 @@ class FileService(CommonService): err, files = [], [] for file in file_objs: try: - MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) + MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0)) if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: raise RuntimeError("Exceed the maximum file number of a free user!") if len(file.filename.encode("utf-8")) >= 128: raise RuntimeError("Exceed the maximum length of file name!") - filename = duplicate_name( - DocumentService.query, - name=file.filename, - kb_id=kb.id) + filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id) filetype = filename_type(filename) if filetype == FileType.OTHER.value: raise RuntimeError("This type of file has not been supported yet!") @@ -446,15 +423,18 @@ class FileService(CommonService): location = filename while STORAGE_IMPL.obj_exist(kb.id, location): location += "_" + blob = file.read() + if filetype == FileType.PDF.value: + blob = read_potential_broken_pdf(blob) STORAGE_IMPL.put(kb.id, location, blob) doc_id = get_uuid() img = thumbnail_img(filename, blob) - thumbnail_location = '' + thumbnail_location = "" if img is not None: - thumbnail_location = f'thumbnail_{doc_id}.png' + thumbnail_location = f"thumbnail_{doc_id}.png" STORAGE_IMPL.put(kb.id, thumbnail_location, img) doc = { @@ -467,7 +447,7 @@ class FileService(CommonService): "name": filename, "location": location, "size": len(blob), - "thumbnail": thumbnail_location + "thumbnail": thumbnail_location, } DocumentService.insert(doc) @@ -480,29 +460,17 @@ class FileService(CommonService): @staticmethod def parse_docs(file_objs, user_id): - from rag.app import presentation, picture, naive, audio, email + from rag.app import audio, email, naive, picture, presentation def dummy(prog=None, msg=""): pass - FACTORY = { - ParserType.PRESENTATION.value: presentation, - ParserType.PICTURE.value: picture, - ParserType.AUDIO.value: audio, - ParserType.EMAIL.value: email - } + FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email} parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"} exe = ThreadPoolExecutor(max_workers=12) threads = [] for file in file_objs: - kwargs = { - "lang": "English", - "callback": dummy, - "parser_config": parser_config, - "from_page": 0, - "to_page": 100000, - "tenant_id": user_id - } + kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": user_id} filetype = filename_type(file.filename) blob = file.read() threads.append(exe.submit(FACTORY.get(FileService.get_parser(filetype, file.filename, ""), naive).chunk, file.filename, blob, **kwargs)) @@ -523,4 +491,5 @@ class FileService(CommonService): return ParserType.PRESENTATION.value if re.search(r"\.(eml)$", filename): return ParserType.EMAIL.value - return default \ No newline at end of file + return default + diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index da84816a1..b90527c70 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -17,17 +17,20 @@ import base64 import json import os import re +import shutil +import subprocess import sys +import tempfile import threading from io import BytesIO import pdfplumber -from PIL import Image from cachetools import LRUCache, cached +from PIL import Image from ruamel.yaml import YAML -from api.db import FileType from api.constants import IMG_BASE64_PREFIX +from api.db import FileType PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") RAG_BASE = os.getenv("RAG_BASE") @@ -74,7 +77,7 @@ def get_rag_python_directory(*args): def get_home_cache_dir(): - dir = os.path.join(os.path.expanduser('~'), ".ragflow") + dir = os.path.join(os.path.expanduser("~"), ".ragflow") try: os.mkdir(dir) except OSError: @@ -92,9 +95,7 @@ def load_json_conf(conf_path): with open(json_conf_path) as f: return json.load(f) except BaseException: - raise EnvironmentError( - "loading json file config from '{}' failed!".format(json_conf_path) - ) + raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path)) def dump_json_conf(config_data, conf_path): @@ -106,9 +107,7 @@ def dump_json_conf(config_data, conf_path): with open(json_conf_path, "w") as f: json.dump(config_data, f, indent=4) except BaseException: - raise EnvironmentError( - "loading json file config from '{}' failed!".format(json_conf_path) - ) + raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path)) def load_json_conf_real_time(conf_path): @@ -120,9 +119,7 @@ def load_json_conf_real_time(conf_path): with open(json_conf_path) as f: return json.load(f) except BaseException: - raise EnvironmentError( - "loading json file config from '{}' failed!".format(json_conf_path) - ) + raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path)) def load_yaml_conf(conf_path): @@ -130,12 +127,10 @@ def load_yaml_conf(conf_path): conf_path = os.path.join(get_project_base_directory(), conf_path) try: with open(conf_path) as f: - yaml = YAML(typ='safe', pure=True) + yaml = YAML(typ="safe", pure=True) return yaml.load(f) except Exception as e: - raise EnvironmentError( - "loading yaml file config from {} failed:".format(conf_path), e - ) + raise EnvironmentError("loading yaml file config from {} failed:".format(conf_path), e) def rewrite_yaml_conf(conf_path, config): @@ -146,13 +141,11 @@ def rewrite_yaml_conf(conf_path, config): yaml = YAML(typ="safe") yaml.dump(config, f) except Exception as e: - raise EnvironmentError( - "rewrite yaml file config {} failed:".format(conf_path), e - ) + raise EnvironmentError("rewrite yaml file config {} failed:".format(conf_path), e) def rewrite_json_file(filepath, json_data): - with open(filepath, "w", encoding='utf-8') as f: + with open(filepath, "w", encoding="utf-8") as f: json.dump(json_data, f, indent=4, separators=(",", ": ")) f.close() @@ -162,12 +155,10 @@ def filename_type(filename): if re.match(r".*\.pdf$", filename): return FileType.PDF.value - if re.match( - r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename): + if re.match(r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename): return FileType.DOC.value - if re.match( - r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): + if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): return FileType.AURAL.value if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename): @@ -175,6 +166,7 @@ def filename_type(filename): return FileType.OTHER.value + def thumbnail_img(filename, blob): """ MySQL LongText max length is 65535 @@ -183,6 +175,7 @@ def thumbnail_img(filename, blob): if re.match(r".*\.pdf$", filename): with sys.modules[LOCK_KEY_pdfplumber]: pdf = pdfplumber.open(BytesIO(blob)) + buffered = BytesIO() resolution = 32 img = None @@ -206,8 +199,9 @@ def thumbnail_img(filename, blob): return buffered.getvalue() elif re.match(r".*\.(ppt|pptx)$", filename): - import aspose.slides as slides import aspose.pydrawing as drawing + import aspose.slides as slides + try: with slides.Presentation(BytesIO(blob)) as presentation: buffered = BytesIO() @@ -215,8 +209,7 @@ def thumbnail_img(filename, blob): img = None for _ in range(10): # https://reference.aspose.com/slides/python-net/aspose.slides/slide/get_thumbnail/#float-float - presentation.slides[0].get_thumbnail(scale, scale).save( - buffered, drawing.imaging.ImageFormat.png) + presentation.slides[0].get_thumbnail(scale, scale).save(buffered, drawing.imaging.ImageFormat.png) img = buffered.getvalue() if len(img) >= 64000: scale = scale / 2.0 @@ -232,10 +225,9 @@ def thumbnail_img(filename, blob): def thumbnail(filename, blob): img = thumbnail_img(filename, blob) if img is not None: - return IMG_BASE64_PREFIX + \ - base64.b64encode(img).decode("utf-8") + return IMG_BASE64_PREFIX + base64.b64encode(img).decode("utf-8") else: - return '' + return "" def traversal_files(base): @@ -243,3 +235,52 @@ def traversal_files(base): for f in fs: fullname = os.path.join(root, f) yield fullname + + +def repair_pdf_with_ghostscript(input_bytes): + if shutil.which("gs") is None: + return input_bytes + + with tempfile.NamedTemporaryFile(suffix=".pdf") as temp_in, tempfile.NamedTemporaryFile(suffix=".pdf") as temp_out: + temp_in.write(input_bytes) + temp_in.flush() + + cmd = [ + "gs", + "-o", + temp_out.name, + "-sDEVICE=pdfwrite", + "-dPDFSETTINGS=/prepress", + temp_in.name, + ] + try: + proc = subprocess.run(cmd, capture_output=True, text=True) + if proc.returncode != 0: + return input_bytes + except Exception: + return input_bytes + + temp_out.seek(0) + repaired_bytes = temp_out.read() + + return repaired_bytes + + +def read_potential_broken_pdf(blob): + def try_open(blob): + try: + with pdfplumber.open(BytesIO(blob)) as pdf: + if pdf.pages: + return True + except Exception: + return False + return False + + if try_open(blob): + return blob + + repaired = repair_pdf_with_ghostscript(blob) + if try_open(repaired): + return repaired + + return blob