diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 502dee6cd..47ca514d9 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -26,7 +26,7 @@ from api.db.db_models import APIToken, API4Conversation, Task, File from api.db.services import duplicate_name from api.db.services.api_service import APITokenService, API4ConversationService from api.db.services.dialog_service import DialogService, chat -from api.db.services.document_service import DocumentService +from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -470,6 +470,29 @@ def upload(): return get_json_result(data=doc_result.to_json()) +@manager.route('/document/upload_and_parse', methods=['POST']) +@validate_request("conversation_id") +def upload_parse(): + token = request.headers.get('Authorization').split()[1] + objs = APIToken.query(token=token) + if not objs: + return get_json_result( + data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) + + if 'file' not in request.files: + return get_json_result( + data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR) + + file_objs = request.files.getlist('file') + for file_obj in file_objs: + if file_obj.filename == '': + return get_json_result( + data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) + + doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id) + return get_json_result(data=doc_ids) + + @manager.route('/list_chunks', methods=['POST']) # @login_required def list_chunks(): @@ -560,7 +583,6 @@ def document_rm(): tenant_id = objs[0].tenant_id req = request.json - doc_ids = [] try: doc_ids = [DocumentService.get_doc_id_by_doc_name(doc_name) for doc_name in req.get("doc_names", [])] for doc_id in req.get("doc_ids", []): diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 721f8d137..8ca804fa7 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -45,7 +45,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid from api.db import FileType, TaskStatus, ParserType, FileSource, LLMType -from api.db.services.document_service import DocumentService +from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.settings import RetCode, stat_logger from api.utils.api_utils import get_json_result from rag.utils.minio_conn import MINIO @@ -75,7 +75,7 @@ def upload(): if not e: raise LookupError("Can't find this knowledgebase!") - err, _ = FileService.upload_document(kb, file_objs) + err, _ = FileService.upload_document(kb, file_objs, current_user.id) if err: return get_json_result( data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR) @@ -212,7 +212,7 @@ def docinfos(): @manager.route('/thumbnails', methods=['GET']) -@login_required +#@login_required def thumbnails(): doc_ids = request.args.get("doc_ids").split(",") if not doc_ids: @@ -460,7 +460,6 @@ def get_image(image_id): @login_required @validate_request("conversation_id") def upload_and_parse(): - from rag.app import presentation, picture, naive, audio, email if 'file' not in request.files: return get_json_result( data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR) @@ -471,124 +470,6 @@ def upload_and_parse(): return get_json_result( data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) - e, conv = ConversationService.get_by_id(request.form.get("conversation_id")) - if not e: - return get_data_error_result(retmsg="Conversation not found!") - e, dia = DialogService.get_by_id(conv.dialog_id) - kb_id = dia.kb_ids[0] - e, kb = KnowledgebaseService.get_by_id(kb_id) - if not e: - raise LookupError("Can't find this knowledgebase!") + doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id) - idxnm = search.index_name(kb.tenant_id) - if not ELASTICSEARCH.indexExist(idxnm): - ELASTICSEARCH.createIdx(idxnm, json.load( - open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r"))) - - embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language) - - err, files = FileService.upload_document(kb, file_objs) - if err: - return get_json_result( - data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR) - - def dummy(prog=None, msg=""): - pass - - FACTORY = { - ParserType.PRESENTATION.value: presentation, - ParserType.PICTURE.value: picture, - ParserType.AUDIO.value: audio, - ParserType.EMAIL.value: email - } - parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False} - exe = ThreadPoolExecutor(max_workers=12) - threads = [] - for d, blob in files: - kwargs = { - "callback": dummy, - "parser_config": parser_config, - "from_page": 0, - "to_page": 100000, - "tenant_id": kb.tenant_id, - "lang": kb.language - } - threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs)) - - for (docinfo,_), th in zip(files, threads): - docs = [] - doc = { - "doc_id": docinfo["id"], - "kb_id": [kb.id] - } - for ck in th.result(): - d = deepcopy(doc) - d.update(ck) - md5 = hashlib.md5() - md5.update((ck["content_with_weight"] + - str(d["doc_id"])).encode("utf-8")) - d["_id"] = md5.hexdigest() - d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] - d["create_timestamp_flt"] = datetime.datetime.now().timestamp() - if not d.get("image"): - docs.append(d) - continue - - output_buffer = BytesIO() - if isinstance(d["image"], bytes): - output_buffer = BytesIO(d["image"]) - else: - d["image"].save(output_buffer, format='JPEG') - - MINIO.put(kb.id, d["_id"], output_buffer.getvalue()) - d["img_id"] = "{}-{}".format(kb.id, d["_id"]) - del d["image"] - docs.append(d) - - parser_ids = {d["id"]: d["parser_id"] for d, _ in files} - docids = [d["id"] for d, _ in files] - chunk_counts = {id: 0 for id in docids} - token_counts = {id: 0 for id in docids} - es_bulk_size = 64 - - def embedding(doc_id, cnts, batch_size=16): - nonlocal embd_mdl, chunk_counts, token_counts - vects = [] - for i in range(0, len(cnts), batch_size): - vts, c = embd_mdl.encode(cnts[i: i + batch_size]) - vects.extend(vts.tolist()) - chunk_counts[doc_id] += len(cnts[i:i + batch_size]) - token_counts[doc_id] += c - return vects - - _, tenant = TenantService.get_by_id(kb.tenant_id) - llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id) - for doc_id in docids: - cks = [c for c in docs if c["doc_id"] == doc_id] - - if False and parser_ids[doc_id] != ParserType.PICTURE.value: - mindmap = MindMapExtractor(llm_bdl) - try: - mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, ensure_ascii=False, indent=2) - if len(mind_map) < 32: raise Exception("Few content: "+mind_map) - cks.append({ - "doc_id": doc_id, - "kb_id": [kb.id], - "content_with_weight": mind_map, - "knowledge_graph_kwd": "mind_map" - }) - except Exception as e: - stat_logger.error("Mind map generation error:", traceback.format_exc()) - - vects = embedding(doc_id, [c["content_with_weight"] for c in cks]) - assert len(cks) == len(vects) - for i, d in enumerate(cks): - v = vects[i] - d["q_%d_vec" % len(v)] = v - for b in range(0, len(cks), es_bulk_size): - ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm) - - DocumentService.increment_chunk_num( - doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) - - return get_json_result(data=[d["id"] for d,_ in files]) + return get_json_result(data=doc_ids) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 0eb2b8c94..328ee924a 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -13,20 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import hashlib +import json +import os import random +from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy from datetime import datetime +from io import BytesIO + from elasticsearch_dsl import Q from peewee import fn from api.db.db_utils import bulk_insert_into_db from api.settings import stat_logger from api.utils import current_timestamp, get_format_time, get_uuid +from api.utils.file_utils import get_project_base_directory +from graphrag.mind_map_extractor import MindMapExtractor from rag.settings import SVR_QUEUE_NAME from rag.utils.es_conn import ELASTICSEARCH from rag.utils.minio_conn import MINIO from rag.nlp import search -from api.db import FileType, TaskStatus, ParserType +from api.db import FileType, TaskStatus, ParserType, LLMType from api.db.db_models import DB, Knowledgebase, Tenant, Task from api.db.db_models import Document from api.db.services.common_service import CommonService @@ -380,3 +389,136 @@ def queue_raptor_tasks(doc): bulk_insert_into_db(Task, [task], True) task["type"] = "raptor" assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status." + + +def doc_upload_and_parse(conversation_id, file_objs, user_id): + from rag.app import presentation, picture, naive, audio, email + from api.db.services.dialog_service import ConversationService, DialogService + from api.db.services.file_service import FileService + from api.db.services.llm_service import LLMBundle + from api.db.services.user_service import TenantService + from api.db.services.api_service import API4ConversationService + + e, conv = ConversationService.get_by_id(conversation_id) + if not e: + e, conv = API4ConversationService.get_by_id(conversation_id) + assert e, "Conversation not found!" + + e, dia = DialogService.get_by_id(conv.dialog_id) + kb_id = dia.kb_ids[0] + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + raise LookupError("Can't find this knowledgebase!") + + idxnm = search.index_name(kb.tenant_id) + if not ELASTICSEARCH.indexExist(idxnm): + ELASTICSEARCH.createIdx(idxnm, json.load( + open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r"))) + + embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language) + + err, files = FileService.upload_document(kb, file_objs, user_id) + assert not err, "\n".join(err) + + def dummy(prog=None, msg=""): + pass + + FACTORY = { + ParserType.PRESENTATION.value: presentation, + ParserType.PICTURE.value: picture, + ParserType.AUDIO.value: audio, + ParserType.EMAIL.value: email + } + parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False} + exe = ThreadPoolExecutor(max_workers=12) + threads = [] + for d, blob in files: + kwargs = { + "callback": dummy, + "parser_config": parser_config, + "from_page": 0, + "to_page": 100000, + "tenant_id": kb.tenant_id, + "lang": kb.language + } + threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs)) + + for (docinfo, _), th in zip(files, threads): + docs = [] + doc = { + "doc_id": docinfo["id"], + "kb_id": [kb.id] + } + for ck in th.result(): + d = deepcopy(doc) + d.update(ck) + md5 = hashlib.md5() + md5.update((ck["content_with_weight"] + + str(d["doc_id"])).encode("utf-8")) + d["_id"] = md5.hexdigest() + d["create_time"] = str(datetime.now()).replace("T", " ")[:19] + d["create_timestamp_flt"] = datetime.now().timestamp() + if not d.get("image"): + docs.append(d) + continue + + output_buffer = BytesIO() + if isinstance(d["image"], bytes): + output_buffer = BytesIO(d["image"]) + else: + d["image"].save(output_buffer, format='JPEG') + + MINIO.put(kb.id, d["_id"], output_buffer.getvalue()) + d["img_id"] = "{}-{}".format(kb.id, d["_id"]) + del d["image"] + docs.append(d) + + parser_ids = {d["id"]: d["parser_id"] for d, _ in files} + docids = [d["id"] for d, _ in files] + chunk_counts = {id: 0 for id in docids} + token_counts = {id: 0 for id in docids} + es_bulk_size = 64 + + def embedding(doc_id, cnts, batch_size=16): + nonlocal embd_mdl, chunk_counts, token_counts + vects = [] + for i in range(0, len(cnts), batch_size): + vts, c = embd_mdl.encode(cnts[i: i + batch_size]) + vects.extend(vts.tolist()) + chunk_counts[doc_id] += len(cnts[i:i + batch_size]) + token_counts[doc_id] += c + return vects + + _, tenant = TenantService.get_by_id(kb.tenant_id) + llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id) + for doc_id in docids: + cks = [c for c in docs if c["doc_id"] == doc_id] + + if parser_ids[doc_id] != ParserType.PICTURE.value: + mindmap = MindMapExtractor(llm_bdl) + try: + mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, + ensure_ascii=False, indent=2) + if len(mind_map) < 32: raise Exception("Few content: " + mind_map) + cks.append({ + "id": get_uuid(), + "doc_id": doc_id, + "kb_id": [kb.id], + "content_with_weight": mind_map, + "knowledge_graph_kwd": "mind_map" + }) + except Exception as e: + stat_logger.error("Mind map generation error:", traceback.format_exc()) + + vects = embedding(doc_id, [c["content_with_weight"] for c in cks]) + assert len(cks) == len(vects) + for i, d in enumerate(cks): + v = vects[i] + d["q_%d_vec" % len(v)] = v + for b in range(0, len(cks), es_bulk_size): + ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm) + + DocumentService.increment_chunk_num( + doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) + + return [d["id"] for d,_ in files] \ No newline at end of file diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 27db6db56..670c4c36a 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -327,11 +327,11 @@ class FileService(CommonService): @classmethod @DB.connection_context() - def upload_document(self, kb, file_objs): - root_folder = self.get_root_folder(current_user.id) + def upload_document(self, kb, file_objs, user_id): + root_folder = self.get_root_folder(user_id) pf_id = root_folder["id"] - self.init_knowledgebase_docs(pf_id, current_user.id) - kb_root_folder = self.get_kb_folder(current_user.id) + self.init_knowledgebase_docs(pf_id, user_id) + kb_root_folder = self.get_kb_folder(user_id) kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) err, files = [], [] @@ -359,7 +359,7 @@ class FileService(CommonService): "kb_id": kb.id, "parser_id": kb.parser_id, "parser_config": kb.parser_config, - "created_by": current_user.id, + "created_by": user_id, "type": filetype, "name": filename, "location": location,