refine upload & parse (#1969)

### What problem does this PR solve?

Moves the body of `upload_and_parse` out of the web route in `document_app.py` into a reusable service function, `doc_upload_and_parse`, in `document_service.py`, and exposes the same flow through a new token-authenticated `/document/upload_and_parse` endpoint in `api_app.py`. `FileService.upload_document` now takes an explicit `user_id` instead of reading `current_user`, so it can be called outside a logged-in web session.
### Type of change
- [x] Refactoring
Kevin Hu committed on 2024-08-15 19:30:43 +08:00 (committed by GitHub)
parent d92e927685 · commit 4810cb2dc9
4 changed files with 177 additions and 132 deletions

#### api/apps/api_app.py

```diff
@@ -26,7 +26,7 @@ from api.db.db_models import APIToken, API4Conversation, Task, File
 from api.db.services import duplicate_name
 from api.db.services.api_service import APITokenService, API4ConversationService
 from api.db.services.dialog_service import DialogService, chat
-from api.db.services.document_service import DocumentService
+from api.db.services.document_service import DocumentService, doc_upload_and_parse
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -470,6 +470,29 @@ def upload():
     return get_json_result(data=doc_result.to_json())
 
 
+@manager.route('/document/upload_and_parse', methods=['POST'])
+@validate_request("conversation_id")
+def upload_parse():
+    token = request.headers.get('Authorization').split()[1]
+    objs = APIToken.query(token=token)
+    if not objs:
+        return get_json_result(
+            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)
+
+    if 'file' not in request.files:
+        return get_json_result(
+            data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
+
+    file_objs = request.files.getlist('file')
+    for file_obj in file_objs:
+        if file_obj.filename == '':
+            return get_json_result(
+                data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
+
+    doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
+
+    return get_json_result(data=doc_ids)
+
+
 @manager.route('/list_chunks', methods=['POST'])
 # @login_required
 def list_chunks():
@@ -560,7 +583,6 @@ def document_rm():
     tenant_id = objs[0].tenant_id
     req = request.json
-    doc_ids = []
     try:
         doc_ids = [DocumentService.get_doc_id_by_doc_name(doc_name) for doc_name in req.get("doc_names", [])]
         for doc_id in req.get("doc_ids", []):
```
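
For reference, the new route can be exercised like this. A minimal sketch, assuming a `Bearer` API token and the `/v1/api` prefix under which these `api_app` routes are typically mounted; the base URL, token, conversation id, and file name are all placeholders, not values from this PR:

```python
import requests

BASE_URL = "http://localhost:9380"        # placeholder: wherever the API server listens
API_TOKEN = "ragflow-xxxxxx"              # placeholder: an API token issued for the tenant
CONVERSATION_ID = "your-conversation-id"  # placeholder

with open("report.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/v1/api/document/upload_and_parse",  # route prefix is an assumption
        headers={"Authorization": f"Bearer {API_TOKEN}"},
        data={"conversation_id": CONVERSATION_ID},
        files=[("file", ("report.pdf", f))],  # repeat the "file" field to send several files
    )
    print(resp.json())  # {"retcode": 0, "data": [<doc_id>, ...]} on success
```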

#### api/apps/document_app.py

```diff
@@ -45,7 +45,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.db import FileType, TaskStatus, ParserType, FileSource, LLMType
-from api.db.services.document_service import DocumentService
+from api.db.services.document_service import DocumentService, doc_upload_and_parse
 from api.settings import RetCode, stat_logger
 from api.utils.api_utils import get_json_result
 from rag.utils.minio_conn import MINIO
@@ -75,7 +75,7 @@ def upload():
     if not e:
         raise LookupError("Can't find this knowledgebase!")
 
-    err, _ = FileService.upload_document(kb, file_objs)
+    err, _ = FileService.upload_document(kb, file_objs, current_user.id)
     if err:
         return get_json_result(
             data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
@@ -212,7 +212,7 @@ def docinfos():
 
 @manager.route('/thumbnails', methods=['GET'])
-@login_required
+#@login_required
 def thumbnails():
     doc_ids = request.args.get("doc_ids").split(",")
     if not doc_ids:
@@ -460,7 +460,6 @@ def get_image(image_id):
 @login_required
 @validate_request("conversation_id")
 def upload_and_parse():
-    from rag.app import presentation, picture, naive, audio, email
     if 'file' not in request.files:
         return get_json_result(
             data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
@@ -471,124 +470,6 @@ def upload_and_parse():
     return get_json_result(
         data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
 
-    e, conv = ConversationService.get_by_id(request.form.get("conversation_id"))
-    if not e:
-        return get_data_error_result(retmsg="Conversation not found!")
-    e, dia = DialogService.get_by_id(conv.dialog_id)
-    kb_id = dia.kb_ids[0]
-    e, kb = KnowledgebaseService.get_by_id(kb_id)
-    if not e:
-        raise LookupError("Can't find this knowledgebase!")
-
-    idxnm = search.index_name(kb.tenant_id)
-    if not ELASTICSEARCH.indexExist(idxnm):
-        ELASTICSEARCH.createIdx(idxnm, json.load(
-            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
-
-    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
-
-    err, files = FileService.upload_document(kb, file_objs)
-    if err:
-        return get_json_result(
-            data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
-
-    def dummy(prog=None, msg=""):
-        pass
-
-    FACTORY = {
-        ParserType.PRESENTATION.value: presentation,
-        ParserType.PICTURE.value: picture,
-        ParserType.AUDIO.value: audio,
-        ParserType.EMAIL.value: email
-    }
-    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
-    exe = ThreadPoolExecutor(max_workers=12)
-    threads = []
-    for d, blob in files:
-        kwargs = {
-            "callback": dummy,
-            "parser_config": parser_config,
-            "from_page": 0,
-            "to_page": 100000,
-            "tenant_id": kb.tenant_id,
-            "lang": kb.language
-        }
-        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
-
-    for (docinfo, _), th in zip(files, threads):
-        docs = []
-        doc = {
-            "doc_id": docinfo["id"],
-            "kb_id": [kb.id]
-        }
-        for ck in th.result():
-            d = deepcopy(doc)
-            d.update(ck)
-            md5 = hashlib.md5()
-            md5.update((ck["content_with_weight"] +
-                        str(d["doc_id"])).encode("utf-8"))
-            d["_id"] = md5.hexdigest()
-            d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
-            d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
-            if not d.get("image"):
-                docs.append(d)
-                continue
-
-            output_buffer = BytesIO()
-            if isinstance(d["image"], bytes):
-                output_buffer = BytesIO(d["image"])
-            else:
-                d["image"].save(output_buffer, format='JPEG')
-
-            MINIO.put(kb.id, d["_id"], output_buffer.getvalue())
-            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
-            del d["image"]
-            docs.append(d)
-
-    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
-    docids = [d["id"] for d, _ in files]
-    chunk_counts = {id: 0 for id in docids}
-    token_counts = {id: 0 for id in docids}
-    es_bulk_size = 64
-
-    def embedding(doc_id, cnts, batch_size=16):
-        nonlocal embd_mdl, chunk_counts, token_counts
-        vects = []
-        for i in range(0, len(cnts), batch_size):
-            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
-            vects.extend(vts.tolist())
-            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
-            token_counts[doc_id] += c
-        return vects
-
-    _, tenant = TenantService.get_by_id(kb.tenant_id)
-    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
-    for doc_id in docids:
-        cks = [c for c in docs if c["doc_id"] == doc_id]
-
-        if False and parser_ids[doc_id] != ParserType.PICTURE.value:
-            mindmap = MindMapExtractor(llm_bdl)
-            try:
-                mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, ensure_ascii=False, indent=2)
-                if len(mind_map) < 32: raise Exception("Few content: " + mind_map)
-                cks.append({
-                    "doc_id": doc_id,
-                    "kb_id": [kb.id],
-                    "content_with_weight": mind_map,
-                    "knowledge_graph_kwd": "mind_map"
-                })
-            except Exception as e:
-                stat_logger.error("Mind map generation error:", traceback.format_exc())
-
-        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
-        assert len(cks) == len(vects)
-        for i, d in enumerate(cks):
-            v = vects[i]
-            d["q_%d_vec" % len(v)] = v
-        for b in range(0, len(cks), es_bulk_size):
-            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
-
-        DocumentService.increment_chunk_num(
-            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
-
-    return get_json_result(data=[d["id"] for d, _ in files])
+    doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
+
+    return get_json_result(data=doc_ids)
```
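
The net effect in `document_app.py` is that the route keeps only request validation and delegates the rest. Condensed from the diffs above (not new behavior), the shared service is now reachable from both authentication styles, which is why it takes an explicit caller id instead of reading `current_user`:

```python
# api/apps/document_app.py — session-authenticated web route
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)

# api/apps/api_app.py — token-authenticated API route (the token's tenant id
# stands in for the logged-in user, since there is no web session here)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
```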

#### api/db/services/document_service.py

```diff
@@ -13,20 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import hashlib
+import json
+import os
 import random
+from concurrent.futures import ThreadPoolExecutor
+from copy import deepcopy
 from datetime import datetime
+from io import BytesIO
 
 from elasticsearch_dsl import Q
 from peewee import fn
 
 from api.db.db_utils import bulk_insert_into_db
 from api.settings import stat_logger
 from api.utils import current_timestamp, get_format_time, get_uuid
+from api.utils.file_utils import get_project_base_directory
+from graphrag.mind_map_extractor import MindMapExtractor
 from rag.settings import SVR_QUEUE_NAME
 from rag.utils.es_conn import ELASTICSEARCH
 from rag.utils.minio_conn import MINIO
 from rag.nlp import search
-from api.db import FileType, TaskStatus, ParserType
+from api.db import FileType, TaskStatus, ParserType, LLMType
 from api.db.db_models import DB, Knowledgebase, Tenant, Task
 from api.db.db_models import Document
 from api.db.services.common_service import CommonService
```
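
The parser dispatch in `doc_upload_and_parse` (next hunk) is a plain dict with a default: look up the document's `parser_id`, fall back to the general-purpose `naive` parser. A minimal sketch of the pattern with stand-in chunkers (the function bodies here are hypothetical, for illustration only):

```python
def naive_chunk(name, blob):   # stand-in for rag.app.naive.chunk
    return [f"naive:{name}"]

def picture_chunk(name, blob):  # stand-in for rag.app.picture.chunk
    return [f"picture:{name}"]

FACTORY = {"picture": picture_chunk}
for parser_id in ("picture", "some-unmapped-parser"):
    chunker = FACTORY.get(parser_id, naive_chunk)  # unmapped ids fall back to naive
    print(chunker("a.bin", b""))  # ['picture:a.bin'] then ['naive:a.bin']
```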
```diff
@@ -380,3 +389,136 @@ def queue_raptor_tasks(doc):
     bulk_insert_into_db(Task, [task], True)
     task["type"] = "raptor"
     assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
+
+
+def doc_upload_and_parse(conversation_id, file_objs, user_id):
+    from rag.app import presentation, picture, naive, audio, email
+    from api.db.services.dialog_service import ConversationService, DialogService
+    from api.db.services.file_service import FileService
+    from api.db.services.llm_service import LLMBundle
+    from api.db.services.user_service import TenantService
+    from api.db.services.api_service import API4ConversationService
+
+    e, conv = ConversationService.get_by_id(conversation_id)
+    if not e:
+        e, conv = API4ConversationService.get_by_id(conversation_id)
+    assert e, "Conversation not found!"
+
+    e, dia = DialogService.get_by_id(conv.dialog_id)
+    kb_id = dia.kb_ids[0]
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+    if not e:
+        raise LookupError("Can't find this knowledgebase!")
+
+    idxnm = search.index_name(kb.tenant_id)
+    if not ELASTICSEARCH.indexExist(idxnm):
+        ELASTICSEARCH.createIdx(idxnm, json.load(
+            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
+
+    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
+
+    err, files = FileService.upload_document(kb, file_objs, user_id)
+    assert not err, "\n".join(err)
+
+    def dummy(prog=None, msg=""):
+        pass
+
+    FACTORY = {
+        ParserType.PRESENTATION.value: presentation,
+        ParserType.PICTURE.value: picture,
+        ParserType.AUDIO.value: audio,
+        ParserType.EMAIL.value: email
+    }
+    parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
+    exe = ThreadPoolExecutor(max_workers=12)
+    threads = []
+    for d, blob in files:
+        kwargs = {
+            "callback": dummy,
+            "parser_config": parser_config,
+            "from_page": 0,
+            "to_page": 100000,
+            "tenant_id": kb.tenant_id,
+            "lang": kb.language
+        }
+        threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
+
+    for (docinfo, _), th in zip(files, threads):
+        docs = []
+        doc = {
+            "doc_id": docinfo["id"],
+            "kb_id": [kb.id]
+        }
+        for ck in th.result():
+            d = deepcopy(doc)
+            d.update(ck)
+            md5 = hashlib.md5()
+            md5.update((ck["content_with_weight"] +
+                        str(d["doc_id"])).encode("utf-8"))
+            d["_id"] = md5.hexdigest()
+            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
+            d["create_timestamp_flt"] = datetime.now().timestamp()
+            if not d.get("image"):
+                docs.append(d)
+                continue
+
+            output_buffer = BytesIO()
+            if isinstance(d["image"], bytes):
+                output_buffer = BytesIO(d["image"])
+            else:
+                d["image"].save(output_buffer, format='JPEG')
+
+            MINIO.put(kb.id, d["_id"], output_buffer.getvalue())
+            d["img_id"] = "{}-{}".format(kb.id, d["_id"])
+            del d["image"]
+            docs.append(d)
+
+    parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
+    docids = [d["id"] for d, _ in files]
+    chunk_counts = {id: 0 for id in docids}
+    token_counts = {id: 0 for id in docids}
+    es_bulk_size = 64
+
+    def embedding(doc_id, cnts, batch_size=16):
+        nonlocal embd_mdl, chunk_counts, token_counts
+        vects = []
+        for i in range(0, len(cnts), batch_size):
+            vts, c = embd_mdl.encode(cnts[i: i + batch_size])
+            vects.extend(vts.tolist())
+            chunk_counts[doc_id] += len(cnts[i:i + batch_size])
+            token_counts[doc_id] += c
+        return vects
+
+    _, tenant = TenantService.get_by_id(kb.tenant_id)
+    llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
+    for doc_id in docids:
+        cks = [c for c in docs if c["doc_id"] == doc_id]
+
+        if parser_ids[doc_id] != ParserType.PICTURE.value:
+            mindmap = MindMapExtractor(llm_bdl)
+            try:
+                mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
+                                      ensure_ascii=False, indent=2)
+                if len(mind_map) < 32:
+                    raise Exception("Few content: " + mind_map)
+                cks.append({
+                    "id": get_uuid(),
+                    "doc_id": doc_id,
+                    "kb_id": [kb.id],
+                    "content_with_weight": mind_map,
+                    "knowledge_graph_kwd": "mind_map"
+                })
+            except Exception as e:
+                stat_logger.error("Mind map generation error:", traceback.format_exc())
+
+        vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
+        assert len(cks) == len(vects)
+        for i, d in enumerate(cks):
+            v = vects[i]
+            d["q_%d_vec" % len(v)] = v
+        for b in range(0, len(cks), es_bulk_size):
+            ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)
+
+        DocumentService.increment_chunk_num(
+            doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
+
+    return [d["id"] for d, _ in files]
```
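
Two implementation details in `doc_upload_and_parse` are worth a closer look. First, each chunk's Elasticsearch `_id` is the MD5 of the chunk text concatenated with its `doc_id`; assuming the bulk call indexes by `_id`, re-parsing an unchanged document rewrites the same chunk records instead of accumulating duplicates. A minimal, self-contained illustration (the sample strings are placeholders):

```python
import hashlib

def chunk_es_id(content_with_weight: str, doc_id: str) -> str:
    # Stable chunk id, computed exactly as in doc_upload_and_parse above.
    md5 = hashlib.md5()
    md5.update((content_with_weight + str(doc_id)).encode("utf-8"))
    return md5.hexdigest()

# Identical text in the same document always yields the same id...
assert chunk_es_id("chunk text", "doc-1") == chunk_es_id("chunk text", "doc-1")
# ...while the same text in another document gets its own id.
assert chunk_es_id("chunk text", "doc-2") != chunk_es_id("chunk text", "doc-1")
```

Second, the inner `embedding()` helper encodes chunks in slices of `batch_size` and tallies per-document token counts as it goes. The same pattern, sketched standalone with a stubbed `encode` standing in for the `LLMBundle` embedding model (hypothetical, for illustration only):

```python
from typing import Callable, List, Sequence, Tuple

def embed_in_batches(
    encode: Callable[[Sequence[str]], Tuple[List[List[float]], int]],
    texts: List[str],
    batch_size: int = 16,
) -> Tuple[List[List[float]], int]:
    """Encode texts batch by batch; return (all vectors, total tokens used)."""
    vects: List[List[float]] = []
    tokens = 0
    for i in range(0, len(texts), batch_size):
        vts, c = encode(texts[i:i + batch_size])  # (batch vectors, token count)
        vects.extend(vts)
        tokens += c
    return vects, tokens

# Stub encoder: one fake 4-dim vector per text, one "token" per character.
fake_encode = lambda batch: ([[0.0] * 4 for _ in batch], sum(len(t) for t in batch))
vecs, used = embed_in_batches(fake_encode, ["a", "bb", "ccc"], batch_size=2)
assert len(vecs) == 3 and used == 6
```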

#### api/db/services/file_service.py

```diff
@@ -327,11 +327,11 @@ class FileService(CommonService):
 
     @classmethod
     @DB.connection_context()
-    def upload_document(self, kb, file_objs):
-        root_folder = self.get_root_folder(current_user.id)
+    def upload_document(self, kb, file_objs, user_id):
+        root_folder = self.get_root_folder(user_id)
         pf_id = root_folder["id"]
-        self.init_knowledgebase_docs(pf_id, current_user.id)
-        kb_root_folder = self.get_kb_folder(current_user.id)
+        self.init_knowledgebase_docs(pf_id, user_id)
+        kb_root_folder = self.get_kb_folder(user_id)
         kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
 
         err, files = [], []
@@ -359,7 +359,7 @@ class FileService(CommonService):
             "kb_id": kb.id,
             "parser_id": kb.parser_id,
             "parser_config": kb.parser_config,
-            "created_by": current_user.id,
+            "created_by": user_id,
             "type": filetype,
             "name": filename,
             "location": location,
```