Feat: repair corrupted PDF files on upload automatically (#7693)

### What problem does this PR solve?

Try the best to repair corrupted PDF files on upload automatically.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei 2025-05-19 14:54:06 +08:00 committed by GitHub
parent 7df1bd4b4a
commit 0ebf05440e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 251 additions and 323 deletions

View File

@ -50,9 +50,9 @@ jobs:
# https://github.com/astral-sh/ruff-action # https://github.com/astral-sh/ruff-action
- name: Static check with Ruff - name: Static check with Ruff
uses: astral-sh/ruff-action@v2 uses: astral-sh/ruff-action@v3
with: with:
version: ">=0.8.2" version: ">=0.11.x"
args: "check" args: "check"
- name: Build ragflow:nightly-slim - name: Build ragflow:nightly-slim

View File

@ -59,7 +59,8 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
apt install -y libatk-bridge2.0-0 && \ apt install -y libatk-bridge2.0-0 && \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \ apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \ apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript
RUN if [ "$NEED_MIRROR" == "1" ]; then \ RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://mirrors.aliyun.com/pypi/simple && \ pip3 config set global.index-url https://mirrors.aliyun.com/pypi/simple && \

View File

@ -20,79 +20,73 @@ import re
import flask import flask
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import current_user, login_required
from deepdoc.parser.html_parser import RAGFlowHtmlParser from api import settings
from rag.nlp import search from api.constants import IMG_BASE64_PREFIX
from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileSource, FileType, ParserType, TaskStatus
from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileType, TaskStatus, ParserType, FileSource
from api.db.db_models import File, Task from api.db.db_models import File, Task
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.task_service import queue_tasks
from api.db.services.user_service import UserTenantService
from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.user_service import UserTenantService
from api.utils import get_uuid
from api.utils.api_utils import ( from api.utils.api_utils import (
server_error_response,
get_data_error_result, get_data_error_result,
get_json_result,
server_error_response,
validate_request, validate_request,
) )
from api.utils import get_uuid from api.utils.file_utils import filename_type, get_project_base_directory, thumbnail
from api import settings
from api.utils.api_utils import get_json_result
from rag.utils.storage_factory import STORAGE_IMPL
from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
from api.utils.web_utils import html2pdf, is_valid_url from api.utils.web_utils import html2pdf, is_valid_url
from api.constants import IMG_BASE64_PREFIX from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search
from rag.utils.storage_factory import STORAGE_IMPL
@manager.route('/upload', methods=['POST']) # noqa: F821 @manager.route("/upload", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("kb_id") @validate_request("kb_id")
def upload(): def upload():
kb_id = request.form.get("kb_id") kb_id = request.form.get("kb_id")
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) if "file" not in request.files:
if 'file' not in request.files: return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
return get_json_result(
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file') file_objs = request.files.getlist("file")
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == "":
return get_json_result( return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
raise LookupError("Can't find this knowledgebase!") raise LookupError("Can't find this knowledgebase!")
err, files = FileService.upload_document(kb, file_objs, current_user.id) err, files = FileService.upload_document(kb, file_objs, current_user.id)
files = [f[0] for f in files] # remove the blob
if not files:
return get_json_result(data=files, message="There seems to be an issue with your file format. Please verify it is correct and not corrupted.", code=settings.RetCode.DATA_ERROR)
files = [f[0] for f in files] # remove the blob
if err: if err:
return get_json_result( return get_json_result(data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=files) return get_json_result(data=files)
@manager.route('/web_crawl', methods=['POST']) # noqa: F821 @manager.route("/web_crawl", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("kb_id", "name", "url") @validate_request("kb_id", "name", "url")
def web_crawl(): def web_crawl():
kb_id = request.form.get("kb_id") kb_id = request.form.get("kb_id")
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
name = request.form.get("name") name = request.form.get("name")
url = request.form.get("url") url = request.form.get("url")
if not is_valid_url(url): if not is_valid_url(url):
return get_json_result( return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
raise LookupError("Can't find this knowledgebase!") raise LookupError("Can't find this knowledgebase!")
@ -108,10 +102,7 @@ def web_crawl():
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
try: try:
filename = duplicate_name( filename = duplicate_name(DocumentService.query, name=name + ".pdf", kb_id=kb.id)
DocumentService.query,
name=name + ".pdf",
kb_id=kb.id)
filetype = filename_type(filename) filetype = filename_type(filename)
if filetype == FileType.OTHER.value: if filetype == FileType.OTHER.value:
raise RuntimeError("This type of file has not been supported yet!") raise RuntimeError("This type of file has not been supported yet!")
@ -130,7 +121,7 @@ def web_crawl():
"name": filename, "name": filename,
"location": location, "location": location,
"size": len(blob), "size": len(blob),
"thumbnail": thumbnail(filename, blob) "thumbnail": thumbnail(filename, blob),
} }
if doc["type"] == FileType.VISUAL: if doc["type"] == FileType.VISUAL:
doc["parser_id"] = ParserType.PICTURE.value doc["parser_id"] = ParserType.PICTURE.value
@ -147,58 +138,53 @@ def web_crawl():
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/create', methods=['POST']) # noqa: F821 @manager.route("/create", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("name", "kb_id") @validate_request("name", "kb_id")
def create(): def create():
req = request.json req = request.json
kb_id = req["kb_id"] kb_id = req["kb_id"]
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
try: try:
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
return get_data_error_result( return get_data_error_result(message="Can't find this knowledgebase!")
message="Can't find this knowledgebase!")
if DocumentService.query(name=req["name"], kb_id=kb_id): if DocumentService.query(name=req["name"], kb_id=kb_id):
return get_data_error_result( return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
message="Duplicated document name in the same knowledgebase.")
doc = DocumentService.insert({ doc = DocumentService.insert(
"id": get_uuid(), {
"kb_id": kb.id, "id": get_uuid(),
"parser_id": kb.parser_id, "kb_id": kb.id,
"parser_config": kb.parser_config, "parser_id": kb.parser_id,
"created_by": current_user.id, "parser_config": kb.parser_config,
"type": FileType.VIRTUAL, "created_by": current_user.id,
"name": req["name"], "type": FileType.VIRTUAL,
"location": "", "name": req["name"],
"size": 0 "location": "",
}) "size": 0,
}
)
return get_json_result(data=doc.to_json()) return get_json_result(data=doc.to_json())
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/list', methods=['POST']) # noqa: F821 @manager.route("/list", methods=["POST"]) # noqa: F821
@login_required @login_required
def list_docs(): def list_docs():
kb_id = request.args.get("kb_id") kb_id = request.args.get("kb_id")
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants: for tenant in tenants:
if KnowledgebaseService.query( if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
tenant_id=tenant.tenant_id, id=kb_id):
break break
else: else:
return get_json_result( return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
data=False, message='Only owner of knowledgebase authorized for this operation.',
code=settings.RetCode.OPERATING_ERROR)
keywords = request.args.get("keywords", "") keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0)) page_number = int(request.args.get("page", 0))
@ -212,83 +198,67 @@ def list_docs():
if run_status: if run_status:
invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS} invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
if invalid_status: if invalid_status:
return get_data_error_result( return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
message=f"Invalid filter run status conditions: {', '.join(invalid_status)}"
)
types = req.get("types", []) types = req.get("types", [])
if types: if types:
invalid_types = {t for t in types if t not in VALID_FILE_TYPES} invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
if invalid_types: if invalid_types:
return get_data_error_result( return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
)
try: try:
docs, tol = DocumentService.get_by_kb_id( docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types)
kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types)
for doc_item in docs: for doc_item in docs:
if doc_item['thumbnail'] and not doc_item['thumbnail'].startswith(IMG_BASE64_PREFIX): if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
doc_item['thumbnail'] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}" doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
return get_json_result(data={"total": tol, "docs": docs}) return get_json_result(data={"total": tol, "docs": docs})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/infos', methods=['POST']) # noqa: F821 @manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required @login_required
def docinfos(): def docinfos():
req = request.json req = request.json
doc_ids = req["doc_ids"] doc_ids = req["doc_ids"]
for doc_id in doc_ids: for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id): if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result( return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
docs = DocumentService.get_by_ids(doc_ids) docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts())) return get_json_result(data=list(docs.dicts()))
@manager.route('/thumbnails', methods=['GET']) # noqa: F821 @manager.route("/thumbnails", methods=["GET"]) # noqa: F821
# @login_required # @login_required
def thumbnails(): def thumbnails():
doc_ids = request.args.get("doc_ids").split(",") doc_ids = request.args.get("doc_ids").split(",")
if not doc_ids: if not doc_ids:
return get_json_result( return get_json_result(data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
try: try:
docs = DocumentService.get_thumbnails(doc_ids) docs = DocumentService.get_thumbnails(doc_ids)
for doc_item in docs: for doc_item in docs:
if doc_item['thumbnail'] and not doc_item['thumbnail'].startswith(IMG_BASE64_PREFIX): if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
doc_item['thumbnail'] = f"/v1/document/image/{doc_item['kb_id']}-{doc_item['thumbnail']}" doc_item["thumbnail"] = f"/v1/document/image/{doc_item['kb_id']}-{doc_item['thumbnail']}"
return get_json_result(data={d["id"]: d["thumbnail"] for d in docs}) return get_json_result(data={d["id"]: d["thumbnail"] for d in docs})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/change_status', methods=['POST']) # noqa: F821 @manager.route("/change_status", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_id", "status") @validate_request("doc_id", "status")
def change_status(): def change_status():
req = request.json req = request.json
if str(req["status"]) not in ["0", "1"]: if str(req["status"]) not in ["0", "1"]:
return get_json_result( return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=settings.RetCode.ARGUMENT_ERROR)
data=False,
message='"Status" must be either 0 or 1!',
code=settings.RetCode.ARGUMENT_ERROR)
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result( return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR)
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
@ -296,23 +266,19 @@ def change_status():
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id) e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e: if not e:
return get_data_error_result( return get_data_error_result(message="Can't find this knowledgebase!")
message="Can't find this knowledgebase!")
if not DocumentService.update_by_id( if not DocumentService.update_by_id(req["doc_id"], {"status": str(req["status"])}):
req["doc_id"], {"status": str(req["status"])}): return get_data_error_result(message="Database error (Document update)!")
return get_data_error_result(
message="Database error (Document update)!")
status = int(req["status"]) status = int(req["status"])
settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
search.index_name(kb.tenant_id), doc.kb_id)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/rm', methods=['POST']) # noqa: F821 @manager.route("/rm", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_id") @validate_request("doc_id")
def rm(): def rm():
@ -323,11 +289,7 @@ def rm():
for doc_id in doc_ids: for doc_id in doc_ids:
if not DocumentService.accessible4deletion(doc_id, current_user.id): if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result( return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"] pf_id = root_folder["id"]
@ -347,8 +309,7 @@ def rm():
TaskService.filter_delete([Task.doc_id == doc_id]) TaskService.filter_delete([Task.doc_id == doc_id])
if not DocumentService.remove_document(doc, tenant_id): if not DocumentService.remove_document(doc, tenant_id):
return get_data_error_result( return get_data_error_result(message="Database error (Document removal)!")
message="Database error (Document removal)!")
f2d = File2DocumentService.get_by_document_id(doc_id) f2d = File2DocumentService.get_by_document_id(doc_id)
deleted_file_count = 0 deleted_file_count = 0
@ -376,18 +337,14 @@ def rm():
return get_json_result(data=True) return get_json_result(data=True)
@manager.route('/run', methods=['POST']) # noqa: F821 @manager.route("/run", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_ids", "run") @validate_request("doc_ids", "run")
def run(): def run():
req = request.json req = request.json
for doc_id in req["doc_ids"]: for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id): if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result( return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
try: try:
kb_table_num_map = {} kb_table_num_map = {}
for id in req["doc_ids"]: for id in req["doc_ids"]:
@ -421,7 +378,7 @@ def run():
if kb_id not in kb_table_num_map: if kb_id not in kb_table_num_map:
count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[]) count = DocumentService.count_by_kb_id(kb_id=kb_id, keywords="", run_status=[TaskStatus.DONE], types=[])
kb_table_num_map[kb_id] = count kb_table_num_map[kb_id] = count
if kb_table_num_map[kb_id] <=0: if kb_table_num_map[kb_id] <= 0:
KnowledgebaseService.delete_field_map(kb_id) KnowledgebaseService.delete_field_map(kb_id)
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"]) bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
queue_tasks(doc, bucket, name, 0) queue_tasks(doc, bucket, name, 0)
@ -431,36 +388,25 @@ def run():
return server_error_response(e) return server_error_response(e)
@manager.route('/rename', methods=['POST']) # noqa: F821 @manager.route("/rename", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_id", "name") @validate_request("doc_id", "name")
def rename(): def rename():
req = request.json req = request.json
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result( return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path( if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
doc.name.lower()).suffix: return get_json_result(data=False, message="The extension of file can't be changed", code=settings.RetCode.ARGUMENT_ERROR)
return get_json_result(
data=False,
message="The extension of file can't be changed",
code=settings.RetCode.ARGUMENT_ERROR)
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]: if d.name == req["name"]:
return get_data_error_result( return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
message="Duplicated document name in the same knowledgebase.")
if not DocumentService.update_by_id( if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
req["doc_id"], {"name": req["name"]}): return get_data_error_result(message="Database error (Document rename)!")
return get_data_error_result(
message="Database error (Document rename)!")
informs = File2DocumentService.get_by_document_id(req["doc_id"]) informs = File2DocumentService.get_by_document_id(req["doc_id"])
if informs: if informs:
@ -472,7 +418,7 @@ def rename():
return server_error_response(e) return server_error_response(e)
@manager.route('/get/<doc_id>', methods=['GET']) # noqa: F821 @manager.route("/get/<doc_id>", methods=["GET"]) # noqa: F821
# @login_required # @login_required
def get(doc_id): def get(doc_id):
try: try:
@ -486,29 +432,22 @@ def get(doc_id):
ext = re.search(r"\.([^.]+)$", doc.name) ext = re.search(r"\.([^.]+)$", doc.name)
if ext: if ext:
if doc.type == FileType.VISUAL.value: if doc.type == FileType.VISUAL.value:
response.headers.set('Content-Type', 'image/%s' % ext.group(1)) response.headers.set("Content-Type", "image/%s" % ext.group(1))
else: else:
response.headers.set( response.headers.set("Content-Type", "application/%s" % ext.group(1))
'Content-Type',
'application/%s' %
ext.group(1))
return response return response
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/change_parser', methods=['POST']) # noqa: F821 @manager.route("/change_parser", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_id", "parser_id") @validate_request("doc_id", "parser_id")
def change_parser(): def change_parser():
req = request.json req = request.json
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result( return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
@ -520,21 +459,16 @@ def change_parser():
else: else:
return get_json_result(data=True) return get_json_result(data=True)
if ((doc.type == FileType.VISUAL and req["parser_id"] != "picture") if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
or (re.search(
r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation")):
return get_data_error_result(message="Not supported yet!") return get_data_error_result(message="Not supported yet!")
e = DocumentService.update_by_id(doc.id, e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
{"parser_id": req["parser_id"], "progress": 0, "progress_msg": "",
"run": TaskStatus.UNSTART.value})
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if "parser_config" in req: if "parser_config" in req:
DocumentService.update_parser_config(doc.id, req["parser_config"]) DocumentService.update_parser_config(doc.id, req["parser_config"])
if doc.token_num > 0: if doc.token_num > 0:
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duation * -1)
doc.process_duation * -1)
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
@ -548,7 +482,7 @@ def change_parser():
return server_error_response(e) return server_error_response(e)
@manager.route('/image/<image_id>', methods=['GET']) # noqa: F821 @manager.route("/image/<image_id>", methods=["GET"]) # noqa: F821
# @login_required # @login_required
def get_image(image_id): def get_image(image_id):
try: try:
@ -557,53 +491,46 @@ def get_image(image_id):
return get_data_error_result(message="Image not found.") return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-") bkt, nm = image_id.split("-")
response = flask.make_response(STORAGE_IMPL.get(bkt, nm)) response = flask.make_response(STORAGE_IMPL.get(bkt, nm))
response.headers.set('Content-Type', 'image/JPEG') response.headers.set("Content-Type", "image/JPEG")
return response return response
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/upload_and_parse', methods=['POST']) # noqa: F821 @manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("conversation_id") @validate_request("conversation_id")
def upload_and_parse(): def upload_and_parse():
if 'file' not in request.files: if "file" not in request.files:
return get_json_result( return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file') file_objs = request.files.getlist("file")
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == "":
return get_json_result( return get_json_result(data=False, message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id) doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
return get_json_result(data=doc_ids) return get_json_result(data=doc_ids)
@manager.route('/parse', methods=['POST']) # noqa: F821 @manager.route("/parse", methods=["POST"]) # noqa: F821
@login_required @login_required
def parse(): def parse():
url = request.json.get("url") if request.json else "" url = request.json.get("url") if request.json else ""
if url: if url:
if not is_valid_url(url): if not is_valid_url(url):
return get_json_result( return get_json_result(data=False, message="The URL format is invalid", code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
download_path = os.path.join(get_project_base_directory(), "logs/downloads") download_path = os.path.join(get_project_base_directory(), "logs/downloads")
os.makedirs(download_path, exist_ok=True) os.makedirs(download_path, exist_ok=True)
from seleniumwire.webdriver import Chrome, ChromeOptions from seleniumwire.webdriver import Chrome, ChromeOptions
options = ChromeOptions() options = ChromeOptions()
options.add_argument('--headless') options.add_argument("--headless")
options.add_argument('--disable-gpu') options.add_argument("--disable-gpu")
options.add_argument('--no-sandbox') options.add_argument("--no-sandbox")
options.add_argument('--disable-dev-shm-usage') options.add_argument("--disable-dev-shm-usage")
options.add_experimental_option('prefs', { options.add_experimental_option("prefs", {"download.default_directory": download_path, "download.prompt_for_download": False, "download.directory_upgrade": True, "safebrowsing.enabled": True})
'download.default_directory': download_path,
'download.prompt_for_download': False,
'download.directory_upgrade': True,
'safebrowsing.enabled': True
})
driver = Chrome(options=options) driver = Chrome(options=options)
driver.get(url) driver.get(url)
res_headers = [r.response.headers for r in driver.requests if r and r.response] res_headers = [r.response.headers for r in driver.requests if r and r.response]
@ -626,51 +553,41 @@ def parse():
r = re.search(r"filename=\"([^\"]+)\"", str(res_headers)) r = re.search(r"filename=\"([^\"]+)\"", str(res_headers))
if not r or not r.group(1): if not r or not r.group(1):
return get_json_result( return get_json_result(data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR)
data=False, message="Can't not identify downloaded file", code=settings.RetCode.ARGUMENT_ERROR)
f = File(r.group(1), os.path.join(download_path, r.group(1))) f = File(r.group(1), os.path.join(download_path, r.group(1)))
txt = FileService.parse_docs([f], current_user.id) txt = FileService.parse_docs([f], current_user.id)
return get_json_result(data=txt) return get_json_result(data=txt)
if 'file' not in request.files: if "file" not in request.files:
return get_json_result( return get_json_result(data=False, message="No file part!", code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file') file_objs = request.files.getlist("file")
txt = FileService.parse_docs(file_objs, current_user.id) txt = FileService.parse_docs(file_objs, current_user.id)
return get_json_result(data=txt) return get_json_result(data=txt)
@manager.route('/set_meta', methods=['POST']) # noqa: F821 @manager.route("/set_meta", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_id", "meta") @validate_request("doc_id", "meta")
def set_meta(): def set_meta():
req = request.json req = request.json
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result( return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
data=False,
message='No authorization.',
code=settings.RetCode.AUTHENTICATION_ERROR
)
try: try:
meta = json.loads(req["meta"]) meta = json.loads(req["meta"])
except Exception as e: except Exception as e:
return get_json_result( return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR)
data=False, message=f'Json syntax error: {e}', code=settings.RetCode.ARGUMENT_ERROR)
if not isinstance(meta, dict): if not isinstance(meta, dict):
return get_json_result( return get_json_result(data=False, message='Meta data should be in Json map format, like {"key": "value"}', code=settings.RetCode.ARGUMENT_ERROR)
data=False, message='Meta data should be in Json map format, like {"key": "value"}', code=settings.RetCode.ARGUMENT_ERROR)
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if not DocumentService.update_by_id( if not DocumentService.update_by_id(req["doc_id"], {"meta_fields": meta}):
req["doc_id"], {"meta_fields": meta}): return get_data_error_result(message="Database error (meta updates)!")
return get_data_error_result(
message="Database error (meta updates)!")
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:

View File

@ -14,22 +14,21 @@
# limitations under the License. # limitations under the License.
# #
import logging import logging
import re
import os import os
import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from flask_login import current_user from flask_login import current_user
from peewee import fn from peewee import fn
from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource, ParserType from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType
from api.db.db_models import DB, File2Document, Knowledgebase from api.db.db_models import DB, Document, File, File2Document, Knowledgebase
from api.db.db_models import File, Document
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.file_utils import filename_type, thumbnail_img from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
@ -39,8 +38,7 @@ class FileService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords):
orderby, desc, keywords):
# Get files by parent folder ID with pagination and filtering # Get files by parent folder ID with pagination and filtering
# Args: # Args:
# tenant_id: ID of the tenant # tenant_id: ID of the tenant
@ -53,17 +51,9 @@ class FileService(CommonService):
# Returns: # Returns:
# Tuple of (file_list, total_count) # Tuple of (file_list, total_count)
if keywords: if keywords:
files = cls.model.select().where( files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), (fn.LOWER(cls.model.name).contains(keywords.lower())), ~(cls.model.id == pf_id))
(cls.model.tenant_id == tenant_id),
(cls.model.parent_id == pf_id),
(fn.LOWER(cls.model.name).contains(keywords.lower())),
~(cls.model.id == pf_id)
)
else: else:
files = cls.model.select().where((cls.model.tenant_id == tenant_id), files = cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == pf_id), ~(cls.model.id == pf_id))
(cls.model.parent_id == pf_id),
~(cls.model.id == pf_id)
)
count = files.count() count = files.count()
if desc: if desc:
files = files.order_by(cls.model.getter_by(orderby).desc()) files = files.order_by(cls.model.getter_by(orderby).desc())
@ -76,16 +66,20 @@ class FileService(CommonService):
for file in res_files: for file in res_files:
if file["type"] == FileType.FOLDER.value: if file["type"] == FileType.FOLDER.value:
file["size"] = cls.get_folder_size(file["id"]) file["size"] = cls.get_folder_size(file["id"])
file['kbs_info'] = [] file["kbs_info"] = []
children = list(cls.model.select().where( children = list(
(cls.model.tenant_id == tenant_id), cls.model.select()
(cls.model.parent_id == file["id"]), .where(
~(cls.model.id == file["id"]), (cls.model.tenant_id == tenant_id),
).dicts()) (cls.model.parent_id == file["id"]),
~(cls.model.id == file["id"]),
)
.dicts()
)
file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children) file["has_child_folder"] = any(value["type"] == FileType.FOLDER.value for value in children)
continue continue
kbs_info = cls.get_kb_id_by_file_id(file['id']) kbs_info = cls.get_kb_id_by_file_id(file["id"])
file['kbs_info'] = kbs_info file["kbs_info"] = kbs_info
return res_files, count return res_files, count
@ -97,16 +91,18 @@ class FileService(CommonService):
# file_id: File ID # file_id: File ID
# Returns: # Returns:
# List of dictionaries containing knowledge base IDs and names # List of dictionaries containing knowledge base IDs and names
kbs = (cls.model.select(*[Knowledgebase.id, Knowledgebase.name]) kbs = (
.join(File2Document, on=(File2Document.file_id == file_id)) cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
.join(Document, on=(File2Document.document_id == Document.id)) .join(File2Document, on=(File2Document.file_id == file_id))
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id)) .join(Document, on=(File2Document.document_id == Document.id))
.where(cls.model.id == file_id)) .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
.where(cls.model.id == file_id)
)
if not kbs: if not kbs:
return [] return []
kbs_info_list = [] kbs_info_list = []
for kb in list(kbs.dicts()): for kb in list(kbs.dicts()):
kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']}) kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]})
return kbs_info_list return kbs_info_list
@classmethod @classmethod
@ -178,16 +174,9 @@ class FileService(CommonService):
if count > len(name) - 2: if count > len(name) - 2:
return file return file
else: else:
file = cls.insert({ file = cls.insert(
"id": get_uuid(), {"id": get_uuid(), "parent_id": parent_id, "tenant_id": current_user.id, "created_by": current_user.id, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value}
"parent_id": parent_id, )
"tenant_id": current_user.id,
"created_by": current_user.id,
"name": name[count],
"location": "",
"size": 0,
"type": FileType.FOLDER.value
})
return cls.create_folder(file, file.id, name, count + 1) return cls.create_folder(file, file.id, name, count + 1)
@classmethod @classmethod
@ -212,9 +201,7 @@ class FileService(CommonService):
# tenant_id: Tenant ID # tenant_id: Tenant ID
# Returns: # Returns:
# Root folder dictionary # Root folder dictionary
for file in cls.model.select().where((cls.model.tenant_id == tenant_id), for file in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
(cls.model.parent_id == cls.model.id)
):
return file.to_dict() return file.to_dict()
file_id = get_uuid() file_id = get_uuid()
@ -239,11 +226,8 @@ class FileService(CommonService):
# tenant_id: Tenant ID # tenant_id: Tenant ID
# Returns: # Returns:
# Knowledge base folder dictionary # Knowledge base folder dictionary
for root in cls.model.select().where( for root in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)):
(cls.model.tenant_id == tenant_id), (cls.model.parent_id == cls.model.id)): for folder in cls.model.select().where((cls.model.tenant_id == tenant_id), (cls.model.parent_id == root.id), (cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)):
for folder in cls.model.select().where(
(cls.model.tenant_id == tenant_id), (cls.model.parent_id == root.id),
(cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)):
return folder.to_dict() return folder.to_dict()
assert False, "Can't find the KB folder. Database init error." assert False, "Can't find the KB folder. Database init error."
@ -271,7 +255,7 @@ class FileService(CommonService):
"type": ty, "type": ty,
"size": size, "size": size,
"location": location, "location": location,
"source_type": FileSource.KNOWLEDGEBASE "source_type": FileSource.KNOWLEDGEBASE,
} }
cls.save(**file) cls.save(**file)
return file return file
@ -283,12 +267,11 @@ class FileService(CommonService):
# Args: # Args:
# root_id: Root folder ID # root_id: Root folder ID
# tenant_id: Tenant ID # tenant_id: Tenant ID
for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)\ for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME) & (cls.model.parent_id == root_id)):
& (cls.model.parent_id == root_id)):
return return
folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id) folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id)
for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id==tenant_id): for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id == tenant_id):
kb_folder = cls.new_a_file_from_kb(tenant_id, kb.name, folder["id"]) kb_folder = cls.new_a_file_from_kb(tenant_id, kb.name, folder["id"])
for doc in DocumentService.query(kb_id=kb.id): for doc in DocumentService.query(kb_id=kb.id):
FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], tenant_id) FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], tenant_id)
@ -357,12 +340,10 @@ class FileService(CommonService):
@DB.connection_context() @DB.connection_context()
def delete_folder_by_pf_id(cls, user_id, folder_id): def delete_folder_by_pf_id(cls, user_id, folder_id):
try: try:
files = cls.model.select().where((cls.model.tenant_id == user_id) files = cls.model.select().where((cls.model.tenant_id == user_id) & (cls.model.parent_id == folder_id))
& (cls.model.parent_id == folder_id))
for file in files: for file in files:
cls.delete_folder_by_pf_id(user_id, file.id) cls.delete_folder_by_pf_id(user_id, file.id)
return cls.model.delete().where((cls.model.tenant_id == user_id) return (cls.model.delete().where((cls.model.tenant_id == user_id) & (cls.model.id == folder_id)).execute(),)
& (cls.model.id == folder_id)).execute(),
except Exception: except Exception:
logging.exception("delete_folder_by_pf_id") logging.exception("delete_folder_by_pf_id")
raise RuntimeError("Database error (File retrieval)!") raise RuntimeError("Database error (File retrieval)!")
@ -380,8 +361,7 @@ class FileService(CommonService):
def dfs(parent_id): def dfs(parent_id):
nonlocal size nonlocal size
for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where( for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id, cls.model.id != parent_id):
cls.model.parent_id == parent_id, cls.model.id != parent_id):
size += f.size size += f.size
if f.type == FileType.FOLDER.value: if f.type == FileType.FOLDER.value:
dfs(f.id) dfs(f.id)
@ -403,7 +383,7 @@ class FileService(CommonService):
"type": doc["type"], "type": doc["type"],
"size": doc["size"], "size": doc["size"],
"location": doc["location"], "location": doc["location"],
"source_type": FileSource.KNOWLEDGEBASE "source_type": FileSource.KNOWLEDGEBASE,
} }
cls.save(**file) cls.save(**file)
File2DocumentService.save(**{"id": get_uuid(), "file_id": file["id"], "document_id": doc["id"]}) File2DocumentService.save(**{"id": get_uuid(), "file_id": file["id"], "document_id": doc["id"]})
@ -412,7 +392,7 @@ class FileService(CommonService):
@DB.connection_context() @DB.connection_context()
def move_file(cls, file_ids, folder_id): def move_file(cls, file_ids, folder_id):
try: try:
cls.filter_update((cls.model.id << file_ids, ), { 'parent_id': folder_id }) cls.filter_update((cls.model.id << file_ids,), {"parent_id": folder_id})
except Exception: except Exception:
logging.exception("move_file") logging.exception("move_file")
raise RuntimeError("Database error (File move)!") raise RuntimeError("Database error (File move)!")
@ -429,16 +409,13 @@ class FileService(CommonService):
err, files = [], [] err, files = [], []
for file in file_objs: for file in file_objs:
try: try:
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
raise RuntimeError("Exceed the maximum file number of a free user!") raise RuntimeError("Exceed the maximum file number of a free user!")
if len(file.filename.encode("utf-8")) >= 128: if len(file.filename.encode("utf-8")) >= 128:
raise RuntimeError("Exceed the maximum length of file name!") raise RuntimeError("Exceed the maximum length of file name!")
filename = duplicate_name( filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
DocumentService.query,
name=file.filename,
kb_id=kb.id)
filetype = filename_type(filename) filetype = filename_type(filename)
if filetype == FileType.OTHER.value: if filetype == FileType.OTHER.value:
raise RuntimeError("This type of file has not been supported yet!") raise RuntimeError("This type of file has not been supported yet!")
@ -446,15 +423,18 @@ class FileService(CommonService):
location = filename location = filename
while STORAGE_IMPL.obj_exist(kb.id, location): while STORAGE_IMPL.obj_exist(kb.id, location):
location += "_" location += "_"
blob = file.read() blob = file.read()
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
STORAGE_IMPL.put(kb.id, location, blob) STORAGE_IMPL.put(kb.id, location, blob)
doc_id = get_uuid() doc_id = get_uuid()
img = thumbnail_img(filename, blob) img = thumbnail_img(filename, blob)
thumbnail_location = '' thumbnail_location = ""
if img is not None: if img is not None:
thumbnail_location = f'thumbnail_{doc_id}.png' thumbnail_location = f"thumbnail_{doc_id}.png"
STORAGE_IMPL.put(kb.id, thumbnail_location, img) STORAGE_IMPL.put(kb.id, thumbnail_location, img)
doc = { doc = {
@ -467,7 +447,7 @@ class FileService(CommonService):
"name": filename, "name": filename,
"location": location, "location": location,
"size": len(blob), "size": len(blob),
"thumbnail": thumbnail_location "thumbnail": thumbnail_location,
} }
DocumentService.insert(doc) DocumentService.insert(doc)
@ -480,29 +460,17 @@ class FileService(CommonService):
@staticmethod @staticmethod
def parse_docs(file_objs, user_id): def parse_docs(file_objs, user_id):
from rag.app import presentation, picture, naive, audio, email from rag.app import audio, email, naive, picture, presentation
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
FACTORY = { FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
ParserType.PRESENTATION.value: presentation,
ParserType.PICTURE.value: picture,
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"} parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
exe = ThreadPoolExecutor(max_workers=12) exe = ThreadPoolExecutor(max_workers=12)
threads = [] threads = []
for file in file_objs: for file in file_objs:
kwargs = { kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": user_id}
"lang": "English",
"callback": dummy,
"parser_config": parser_config,
"from_page": 0,
"to_page": 100000,
"tenant_id": user_id
}
filetype = filename_type(file.filename) filetype = filename_type(file.filename)
blob = file.read() blob = file.read()
threads.append(exe.submit(FACTORY.get(FileService.get_parser(filetype, file.filename, ""), naive).chunk, file.filename, blob, **kwargs)) threads.append(exe.submit(FACTORY.get(FileService.get_parser(filetype, file.filename, ""), naive).chunk, file.filename, blob, **kwargs))
@ -524,3 +492,4 @@ class FileService(CommonService):
if re.search(r"\.(eml)$", filename): if re.search(r"\.(eml)$", filename):
return ParserType.EMAIL.value return ParserType.EMAIL.value
return default return default

View File

@ -17,17 +17,20 @@ import base64
import json import json
import os import os
import re import re
import shutil
import subprocess
import sys import sys
import tempfile
import threading import threading
from io import BytesIO from io import BytesIO
import pdfplumber import pdfplumber
from PIL import Image
from cachetools import LRUCache, cached from cachetools import LRUCache, cached
from PIL import Image
from ruamel.yaml import YAML from ruamel.yaml import YAML
from api.db import FileType
from api.constants import IMG_BASE64_PREFIX from api.constants import IMG_BASE64_PREFIX
from api.db import FileType
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
RAG_BASE = os.getenv("RAG_BASE") RAG_BASE = os.getenv("RAG_BASE")
@ -74,7 +77,7 @@ def get_rag_python_directory(*args):
def get_home_cache_dir(): def get_home_cache_dir():
dir = os.path.join(os.path.expanduser('~'), ".ragflow") dir = os.path.join(os.path.expanduser("~"), ".ragflow")
try: try:
os.mkdir(dir) os.mkdir(dir)
except OSError: except OSError:
@ -92,9 +95,7 @@ def load_json_conf(conf_path):
with open(json_conf_path) as f: with open(json_conf_path) as f:
return json.load(f) return json.load(f)
except BaseException: except BaseException:
raise EnvironmentError( raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))
"loading json file config from '{}' failed!".format(json_conf_path)
)
def dump_json_conf(config_data, conf_path): def dump_json_conf(config_data, conf_path):
@ -106,9 +107,7 @@ def dump_json_conf(config_data, conf_path):
with open(json_conf_path, "w") as f: with open(json_conf_path, "w") as f:
json.dump(config_data, f, indent=4) json.dump(config_data, f, indent=4)
except BaseException: except BaseException:
raise EnvironmentError( raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))
"loading json file config from '{}' failed!".format(json_conf_path)
)
def load_json_conf_real_time(conf_path): def load_json_conf_real_time(conf_path):
@ -120,9 +119,7 @@ def load_json_conf_real_time(conf_path):
with open(json_conf_path) as f: with open(json_conf_path) as f:
return json.load(f) return json.load(f)
except BaseException: except BaseException:
raise EnvironmentError( raise EnvironmentError("loading json file config from '{}' failed!".format(json_conf_path))
"loading json file config from '{}' failed!".format(json_conf_path)
)
def load_yaml_conf(conf_path): def load_yaml_conf(conf_path):
@ -130,12 +127,10 @@ def load_yaml_conf(conf_path):
conf_path = os.path.join(get_project_base_directory(), conf_path) conf_path = os.path.join(get_project_base_directory(), conf_path)
try: try:
with open(conf_path) as f: with open(conf_path) as f:
yaml = YAML(typ='safe', pure=True) yaml = YAML(typ="safe", pure=True)
return yaml.load(f) return yaml.load(f)
except Exception as e: except Exception as e:
raise EnvironmentError( raise EnvironmentError("loading yaml file config from {} failed:".format(conf_path), e)
"loading yaml file config from {} failed:".format(conf_path), e
)
def rewrite_yaml_conf(conf_path, config): def rewrite_yaml_conf(conf_path, config):
@ -146,13 +141,11 @@ def rewrite_yaml_conf(conf_path, config):
yaml = YAML(typ="safe") yaml = YAML(typ="safe")
yaml.dump(config, f) yaml.dump(config, f)
except Exception as e: except Exception as e:
raise EnvironmentError( raise EnvironmentError("rewrite yaml file config {} failed:".format(conf_path), e)
"rewrite yaml file config {} failed:".format(conf_path), e
)
def rewrite_json_file(filepath, json_data): def rewrite_json_file(filepath, json_data):
with open(filepath, "w", encoding='utf-8') as f: with open(filepath, "w", encoding="utf-8") as f:
json.dump(json_data, f, indent=4, separators=(",", ": ")) json.dump(json_data, f, indent=4, separators=(",", ": "))
f.close() f.close()
@ -162,12 +155,10 @@ def filename_type(filename):
if re.match(r".*\.pdf$", filename): if re.match(r".*\.pdf$", filename):
return FileType.PDF.value return FileType.PDF.value
if re.match( if re.match(r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
return FileType.DOC.value return FileType.DOC.value
if re.match( if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
return FileType.AURAL.value return FileType.AURAL.value
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename): if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
@ -175,6 +166,7 @@ def filename_type(filename):
return FileType.OTHER.value return FileType.OTHER.value
def thumbnail_img(filename, blob): def thumbnail_img(filename, blob):
""" """
MySQL LongText max length is 65535 MySQL LongText max length is 65535
@ -183,6 +175,7 @@ def thumbnail_img(filename, blob):
if re.match(r".*\.pdf$", filename): if re.match(r".*\.pdf$", filename):
with sys.modules[LOCK_KEY_pdfplumber]: with sys.modules[LOCK_KEY_pdfplumber]:
pdf = pdfplumber.open(BytesIO(blob)) pdf = pdfplumber.open(BytesIO(blob))
buffered = BytesIO() buffered = BytesIO()
resolution = 32 resolution = 32
img = None img = None
@ -206,8 +199,9 @@ def thumbnail_img(filename, blob):
return buffered.getvalue() return buffered.getvalue()
elif re.match(r".*\.(ppt|pptx)$", filename): elif re.match(r".*\.(ppt|pptx)$", filename):
import aspose.slides as slides
import aspose.pydrawing as drawing import aspose.pydrawing as drawing
import aspose.slides as slides
try: try:
with slides.Presentation(BytesIO(blob)) as presentation: with slides.Presentation(BytesIO(blob)) as presentation:
buffered = BytesIO() buffered = BytesIO()
@ -215,8 +209,7 @@ def thumbnail_img(filename, blob):
img = None img = None
for _ in range(10): for _ in range(10):
# https://reference.aspose.com/slides/python-net/aspose.slides/slide/get_thumbnail/#float-float # https://reference.aspose.com/slides/python-net/aspose.slides/slide/get_thumbnail/#float-float
presentation.slides[0].get_thumbnail(scale, scale).save( presentation.slides[0].get_thumbnail(scale, scale).save(buffered, drawing.imaging.ImageFormat.png)
buffered, drawing.imaging.ImageFormat.png)
img = buffered.getvalue() img = buffered.getvalue()
if len(img) >= 64000: if len(img) >= 64000:
scale = scale / 2.0 scale = scale / 2.0
@ -232,10 +225,9 @@ def thumbnail_img(filename, blob):
def thumbnail(filename, blob): def thumbnail(filename, blob):
img = thumbnail_img(filename, blob) img = thumbnail_img(filename, blob)
if img is not None: if img is not None:
return IMG_BASE64_PREFIX + \ return IMG_BASE64_PREFIX + base64.b64encode(img).decode("utf-8")
base64.b64encode(img).decode("utf-8")
else: else:
return '' return ""
def traversal_files(base): def traversal_files(base):
@ -243,3 +235,52 @@ def traversal_files(base):
for f in fs: for f in fs:
fullname = os.path.join(root, f) fullname = os.path.join(root, f)
yield fullname yield fullname
def repair_pdf_with_ghostscript(input_bytes):
if shutil.which("gs") is None:
return input_bytes
with tempfile.NamedTemporaryFile(suffix=".pdf") as temp_in, tempfile.NamedTemporaryFile(suffix=".pdf") as temp_out:
temp_in.write(input_bytes)
temp_in.flush()
cmd = [
"gs",
"-o",
temp_out.name,
"-sDEVICE=pdfwrite",
"-dPDFSETTINGS=/prepress",
temp_in.name,
]
try:
proc = subprocess.run(cmd, capture_output=True, text=True)
if proc.returncode != 0:
return input_bytes
except Exception:
return input_bytes
temp_out.seek(0)
repaired_bytes = temp_out.read()
return repaired_bytes
def read_potential_broken_pdf(blob):
    """Return a readable version of the given PDF bytes.

    If ``blob`` already opens cleanly with pdfplumber it is returned as-is;
    otherwise a Ghostscript repair is attempted and the repaired bytes are
    returned only when they open cleanly. Falls back to ``blob`` unchanged.
    """

    def _openable(data):
        # A PDF counts as readable only if pdfplumber opens it and it has pages.
        try:
            with pdfplumber.open(BytesIO(data)) as pdf:
                return bool(pdf.pages)
        except Exception:
            return False

    if _openable(blob):
        return blob
    repaired = repair_pdf_with_ghostscript(blob)
    return repaired if _openable(repaired) else blob