apply pep8 formalize (#155)

KevinHuSh authored 2024-03-27 11:33:46 +08:00, committed by GitHub
parent a02e836790
commit fd7fcb5baf
55 changed files with 1568 additions and 753 deletions
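
The change set is mechanical PEP 8 reformatting: long calls and dict literals are wrapped at argument boundaries, and one-line `if ...: return ...` statements are expanded onto separate lines. A minimal sketch of reproducing that kind of rewrite with autopep8 (the tool and flags are an assumption; the commit message only says "apply pep8 formalize"):

# Hypothetical reproduction of the reformatting; the exact command used is not stated in the commit.
import autopep8

source = 'd = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}\n'
print(autopep8.fix_code(source, options={"max_line_length": 79, "aggressive": 1}))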

View File

@@ -121,7 +121,9 @@ def get():
                  "important_kwd")
def set():
    req = request.json
    d = {
        "id": req["chunk_id"],
        "content_with_weight": req["content_with_weight"]}
    d["content_ltks"] = huqie.qie(req["content_with_weight"])
    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
    d["important_kwd"] = req["important_kwd"]
@@ -140,10 +142,16 @@ def set():
            return get_data_error_result(retmsg="Document not found!")
        if doc.parser_id == ParserType.QA:
            arr = [
                t for t in re.split(
                    r"[\n\t]",
                    req["content_with_weight"]) if len(t) > 1]
            if len(arr) != 2:
                return get_data_error_result(
                    retmsg="Q&A must be separated by TAB/ENTER key.")
            q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
            d = beAdoc(d, arr[0], arr[1], not any(
                [huqie.is_chinese(t) for t in q + a]))

        v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
@@ -177,7 +185,8 @@ def switch():
def rm():
    req = request.json
    try:
        if not ELASTICSEARCH.deleteByQuery(
                Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
            return get_data_error_result(retmsg="Index updating failure")
        return get_json_result(data=True)
    except Exception as e:

View File

@@ -100,7 +100,10 @@ def rm():
def list_convsersation():
    dialog_id = request.args["dialog_id"]
    try:
        convs = ConversationService.query(
            dialog_id=dialog_id,
            order_by=ConversationService.model.create_time,
            reverse=True)
        convs = [d.to_dict() for d in convs]
        return get_json_result(data=convs)
    except Exception as e:
@@ -111,19 +114,24 @@ def message_fit_in(msg, max_length=4000):
    def count():
        nonlocal msg
        tks_cnts = []
        for m in msg:
            tks_cnts.append(
                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
        total = 0
        for m in tks_cnts:
            total += m["count"]
        return total

    c = count()
    if c < max_length:
        return c, msg
    msg_ = [m for m in msg[:-1] if m.role == "system"]
    msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    ll = num_tokens_from_string(msg_[0].content)
    l = num_tokens_from_string(msg_[-1].content)
@@ -146,8 +154,10 @@ def completion():
    req = request.json
    msg = []
    for m in req["messages"]:
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append({"role": m["role"], "content": m["content"]})
    try:
        e, conv = ConversationService.get_by_id(req["conversation_id"])
@@ -160,7 +170,8 @@ def completion():
        del req["conversation_id"]
        del req["messages"]
        ans = chat(dia, msg, **req)
        if not conv.reference:
            conv.reference = []
        conv.reference.append(ans["reference"])
        conv.message.append({"role": "assistant", "content": ans["answer"]})
        ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -180,52 +191,67 @@ def chat(dialog, messages, **kwargs):
    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    # try to use sql if field mapping is good to go
    if field_map:
        chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
        return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)

    prompt_config = dialog.prompt_config
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")

    for _ in range(len(questions) // 2):
        questions.append(questions[-1])
    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
    else:
        kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                        dialog.similarity_threshold,
                                        dialog.vector_similarity_weight, top=1024, aggs=False)
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    chat_logger.info(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    if not knowledges and prompt_config.get("empty_response"):
        return {
            "answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n".join(knowledges)
    gen_conf = dialog.llm_setting
    msg = [{"role": m["role"], "content": m["content"]}
           for m in messages if m["role"] != "system"]
    used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            llm.max_tokens - used_token_count)
    answer = chat_mdl.chat(
        prompt_config["system"].format(
            **kwargs), msg, gen_conf)
    chat_logger.info("User: {}|Assistant: {}".format(
        msg[-1]["content"], answer))

    if knowledges:
        answer, idx = retrievaler.insert_citations(answer,
                                                   [ck["content_ltks"]
                                                       for ck in kbinfos["chunks"]],
                                                   [ck["vector"]
                                                       for ck in kbinfos["chunks"]],
                                                   embd_mdl,
                                                   tkweight=1 - dialog.vector_similarity_weight,
                                                   vtweight=dialog.vector_similarity_weight)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        kbinfos["doc_aggs"] = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]

    for c in kbinfos["chunks"]:
        if c.get("vector"):
            del c["vector"]
    return {"answer": answer, "reference": kbinfos}
@@ -245,9 +271,11 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
        question
    )
    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_promt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
                            "temperature": 0.06})
        print(user_promt, sql)
        chat_logger.info(f"{question}”==>{user_promt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
@@ -262,8 +290,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
@@ -302,16 +332,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
    chat_logger.info("GET table: {}".format(tbl))
    print(tbl)
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None, None

    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    docnm_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    clmn_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]

    # compose markdown table
    clmns = "|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"],
                                                                    tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
    line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
        ("|------|" if docid_idx and docid_idx else "")
    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]

    if not docid_idx or not docnm_idx:
        chat_logger.warning("SQL missing field: " + sql)
        return "\n".join([clmns, line, "\n".join(rows)]), []
@@ -328,5 +366,5 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
    return {
        "answer": "\n".join([clmns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
    }

View File

@@ -55,7 +55,8 @@ def set_dialog():
    }
    prompt_config = req.get("prompt_config", default_prompt)
    if not prompt_config["system"]:
        prompt_config["system"] = default_prompt["system"]

    # if len(prompt_config["parameters"]) < 1:
    #     prompt_config["parameters"] = default_prompt["parameters"]
    # for p in prompt_config["parameters"]:
@@ -63,16 +64,21 @@ def set_dialog():
    # else: prompt_config["parameters"].append(default_prompt["parameters"][0])
    for p in prompt_config["parameters"]:
        if p["optional"]:
            continue
        if prompt_config["system"].find("{%s}" % p["key"]) < 0:
            return get_data_error_result(
                retmsg="Parameter '{}' is not used".format(p["key"]))

    try:
        e, tenant = TenantService.get_by_id(current_user.id)
        if not e:
            return get_data_error_result(retmsg="Tenant not found!")
        llm_id = req.get("llm_id", tenant.llm_id)
        if not dialog_id:
            if not req.get("kb_ids"):
                return get_data_error_result(
                    retmsg="Fail! Please select knowledgebase!")
            dia = {
                "id": get_uuid(),
                "tenant_id": current_user.id,
@@ -86,17 +92,21 @@ def set_dialog():
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight
            }
            if not DialogService.save(**dia):
                return get_data_error_result(retmsg="Fail to new a dialog!")
            e, dia = DialogService.get_by_id(dia["id"])
            if not e:
                return get_data_error_result(retmsg="Fail to new a dialog!")
            return get_json_result(data=dia.to_json())
        else:
            del req["dialog_id"]
            if "kb_names" in req:
                del req["kb_names"]
            if not DialogService.update_by_id(dialog_id, req):
                return get_data_error_result(retmsg="Dialog not found!")
            e, dia = DialogService.get_by_id(dialog_id)
            if not e:
                return get_data_error_result(retmsg="Fail to update a dialog!")
            dia = dia.to_dict()
            dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
            return get_json_result(data=dia)
@@ -110,7 +120,8 @@ def get():
    dialog_id = request.args["dialog_id"]
    try:
        e, dia = DialogService.get_by_id(dialog_id)
        if not e:
            return get_data_error_result(retmsg="Dialog not found!")
        dia = dia.to_dict()
        dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
        return get_json_result(data=dia)
@@ -122,7 +133,8 @@ def get_kb_names(kb_ids):
    ids, nms = [], []
    for kid in kb_ids:
        e, kb = KnowledgebaseService.get_by_id(kid)
        if not e or kb.status != StatusEnum.VALID.value:
            continue
        ids.append(kid)
        nms.append(kb.name)
    return ids, nms
@@ -132,7 +144,11 @@ def get_kb_names(kb_ids):
@login_required
def list():
    try:
        diags = DialogService.query(
            tenant_id=current_user.id,
            status=StatusEnum.VALID.value,
            reverse=True,
            order_by=DialogService.model.create_time)
        diags = [d.to_dict() for d in diags]
        for d in diags:
            d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
@@ -147,7 +163,8 @@ def list():
def rm():
    req = request.json
    try:
        DialogService.update_many_by_id(
            [{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)

View File

@@ -57,6 +57,9 @@ def upload():
        if not e:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")
        if DocumentService.get_doc_count(kb.tenant_id) >= 128:
            return get_data_error_result(
                retmsg="Exceed the maximum file number of a free user!")

        filename = duplicate_name(
            DocumentService.query,
@@ -215,9 +218,11 @@ def rm():
        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))

        DocumentService.increment_chunk_num(
            doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
        if not DocumentService.delete(doc):
            return get_data_error_result(
                retmsg="Database error (Document removal)!")
@@ -245,7 +250,8 @@ def run():
            tenant_id = DocumentService.get_tenant_id(id)
            if not tenant_id:
                return get_data_error_result(retmsg="Tenant not found!")
            ELASTICSEARCH.deleteByQuery(
                Q("match", doc_id=id), idxnm=search.index_name(tenant_id))

        return get_json_result(data=True)
    except Exception as e:
@@ -261,7 +267,8 @@ def rename():
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
                doc.name.lower()).suffix:
            return get_json_result(
                data=False,
                retmsg="The extension of file can't be changed",
@@ -294,7 +301,10 @@ def get(doc_id):
        if doc.type == FileType.VISUAL.value:
            response.headers.set('Content-Type', 'image/%s' % ext.group(1))
        else:
            response.headers.set(
                'Content-Type',
                'application/%s' %
                ext.group(1))
        return response
    except Exception as e:
        return server_error_response(e)
@@ -313,9 +323,11 @@ def change_parser():
        if "parser_config" in req:
            if req["parser_config"] == doc.parser_config:
                return get_json_result(data=True)
        else:
            return get_json_result(data=True)

        if doc.type == FileType.VISUAL or re.search(
                r"\.(ppt|pptx|pages)$", doc.name):
            return get_data_error_result(retmsg="Not supported yet!")

        e = DocumentService.update_by_id(doc.id,
@@ -332,7 +344,8 @@ def change_parser():
            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
            if not tenant_id:
                return get_data_error_result(retmsg="Tenant not found!")
            ELASTICSEARCH.deleteByQuery(
                Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))

        return get_json_result(data=True)
    except Exception as e:

View File

@@ -33,15 +33,21 @@ from api.utils.api_utils import get_json_result
def create():
    req = request.json
    req["name"] = req["name"].strip()
    req["name"] = duplicate_name(
        KnowledgebaseService.query,
        name=req["name"],
        tenant_id=current_user.id,
        status=StatusEnum.VALID.value)
    try:
        req["id"] = get_uuid()
        req["tenant_id"] = current_user.id
        req["created_by"] = current_user.id
        e, t = TenantService.get_by_id(current_user.id)
        if not e:
            return get_data_error_result(retmsg="Tenant not found.")
        req["embd_id"] = t.embd_id
        if not KnowledgebaseService.save(**req):
            return get_data_error_result()
        return get_json_result(data={"kb_id": req["id"]})
    except Exception as e:
        return server_error_response(e)
@@ -54,21 +60,29 @@ def update():
    req = request.json
    req["name"] = req["name"].strip()
    try:
        if not KnowledgebaseService.query(
                created_by=current_user.id, id=req["kb_id"]):
            return get_json_result(
                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)

        e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
        if not e:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")

        if req["name"].lower() != kb.name.lower() \
                and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1:
            return get_data_error_result(
                retmsg="Duplicated knowledgebase name.")

        del req["kb_id"]
        if not KnowledgebaseService.update_by_id(kb.id, req):
            return get_data_error_result()

        e, kb = KnowledgebaseService.get_by_id(kb.id)
        if not e:
            return get_data_error_result(
                retmsg="Database error (Knowledgebase rename)!")

        return get_json_result(data=kb.to_json())
    except Exception as e:
@@ -81,7 +95,9 @@ def detail():
    kb_id = request.args["kb_id"]
    try:
        kb = KnowledgebaseService.get_detail(kb_id)
        if not kb:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")
        return get_json_result(data=kb)
    except Exception as e:
        return server_error_response(e)
@@ -96,7 +112,8 @@ def list():
    desc = request.args.get("desc", True)
    try:
        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
        kbs = KnowledgebaseService.get_by_tenant_ids(
            [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
        return get_json_result(data=kbs)
    except Exception as e:
        return server_error_response(e)
@@ -108,10 +125,15 @@ def list():
def rm():
    req = request.json
    try:
        if not KnowledgebaseService.query(
                created_by=current_user.id, id=req["kb_id"]):
            return get_json_result(
                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
        if not KnowledgebaseService.update_by_id(
                req["kb_id"], {"status": StatusEnum.INVALID.value}):
            return get_data_error_result(
                retmsg="Database error (Knowledgebase removal)!")

        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)

View File

@@ -48,30 +48,42 @@ def set_api_key():
                req["api_key"], llm.llm_name)
            try:
                arr, tc = mdl.encode(["Test if the api key is available"])
                if len(arr[0]) == 0 or tc == 0:
                    raise Exception("Fail")
            except Exception as e:
                msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
        elif not chat_passed and llm.model_type == LLMType.CHAT.value:
            mdl = ChatModel[factory](
                req["api_key"], llm.llm_name)
            try:
                m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
                                 "temperature": 0.9})
                if not tc:
                    raise Exception(m)
                chat_passed = True
            except Exception as e:
                msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
                    e)

    if msg:
        return get_data_error_result(retmsg=msg)

    llm = {
        "api_key": req["api_key"]
    }
    for n in ["model_type", "llm_name"]:
        if n in req:
            llm[n] = req[n]

    if not TenantLLMService.filter_update(
            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm):
        for llm in LLMService.query(fid=factory):
            TenantLLMService.save(
                tenant_id=current_user.id,
                llm_factory=factory,
                llm_name=llm.llm_name,
                model_type=llm.model_type,
                api_key=req["api_key"])

    return get_json_result(data=True)
@@ -105,17 +117,19 @@ def list():
        objs = TenantLLMService.query(tenant_id=current_user.id)
        facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
        llms = LLMService.get_all()
        llms = [m.to_dict()
                for m in llms if m.status == StatusEnum.VALID.value]
        for m in llms:
            m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"

        res = {}
        for m in llms:
            if model_type and m["model_type"] != model_type:
                continue
            if m["fid"] not in res:
                res[m["fid"]] = []
            res[m["fid"]].append(m)

        return get_json_result(data=res)
    except Exception as e:
        return server_error_response(e)

View File

@@ -40,13 +40,16 @@ def login():
    email = request.json.get('email', "")
    users = UserService.query(email=email)
    if not users:
        return get_json_result(
            data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')

    password = request.json.get('password')
    try:
        password = decrypt(password)
    except BaseException:
        return get_json_result(
            data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')

    user = UserService.query_user(email, password)
    if user:
@@ -57,7 +60,8 @@ def login():
        msg = "Welcome back!"
        return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
    else:
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                               retmsg='Email and Password do not match!')


@manager.route('/github_callback', methods=['GET'])
@@ -96,15 +100,17 @@ def github_callback():
                "last_login_time": get_format_time(),
                "is_superuser": False,
            })
            if not users:
                raise Exception('Register user failure.')
            if len(users) > 1:
                raise Exception('Same E-mail exist!')
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            stat_logger.exception(e)
            return redirect("/?error=%s" % str(e))

    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
@@ -114,11 +120,18 @@ def github_callback():
def user_info_from_github(access_token):
    import requests
    headers = {"Accept": "application/json",
               'Authorization': f"token {access_token}"}
    res = requests.get(
        f"https://api.github.com/user?access_token={access_token}",
        headers=headers)
    user_info = res.json()
    email_info = requests.get(
        f"https://api.github.com/user/emails?access_token={access_token}",
        headers=headers).json()
    user_info["email"] = next(
        (email for email in email_info if email['primary'] == True),
        None)["email"]
    return user_info
@@ -138,13 +151,18 @@ def setting_user():
    request_data = request.json
    if request_data.get("password"):
        new_password = request_data.get("new_password")
        if not check_password_hash(
                current_user.password, decrypt(request_data["password"])):
            return get_json_result(
                data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')

        if new_password:
            update_dict["password"] = generate_password_hash(
                decrypt(new_password))

    for k in request_data.keys():
        if k in ["password", "new_password"]:
            continue
        update_dict[k] = request_data[k]

    try:
@@ -152,7 +170,8 @@ def setting_user():
        return get_json_result(data=True)
    except Exception as e:
        stat_logger.exception(e)
        return get_json_result(
            data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)


@manager.route("/info", methods=["GET"])
@@ -173,7 +192,7 @@ def rollback_user_registration(user_id):
    except Exception as e:
        pass
    try:
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception as e:
        pass
@@ -197,9 +216,14 @@ def user_register(user_id, user):
    }
    tenant_llm = []
    for llm in LLMService.query(fid=LLM_FACTORY):
        tenant_llm.append({"tenant_id": user_id,
                           "llm_factory": LLM_FACTORY,
                           "llm_name": llm.llm_name,
                           "model_type": llm.model_type,
                           "api_key": API_KEY})

    if not UserService.save(**user):
        return
    TenantService.insert(**tenant)
    UserTenantService.insert(**usr_tenant)
    TenantLLMService.insert_many(tenant_llm)
@@ -211,7 +235,8 @@ def user_register(user_id, user):
def user_add():
    req = request.json
    if UserService.query(email=req["email"]):
        return get_json_result(
            data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
    if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
        return get_json_result(data=False, retmsg=f'Invaliad e-mail: {req["email"]}!',
                               retcode=RetCode.OPERATING_ERROR)
@@ -229,16 +254,19 @@ def user_add():
    user_id = get_uuid()
    try:
        users = user_register(user_id, user_dict)
        if not users:
            raise Exception('Register user failure.')
        if len(users) > 1:
            raise Exception('Same E-mail exist!')
        user = users[0]
        login_user(user)
        return cors_reponse(data=user.to_json(),
                            auth=user.get_id(), retmsg="Welcome aboard!")
    except Exception as e:
        rollback_user_registration(user_id)
        stat_logger.exception(e)
        return get_json_result(
            data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)


@manager.route("/tenant_info", methods=["GET"])

View File

@@ -50,7 +50,13 @@ def singleton(cls, *args, **kw):
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
    "create",
    "start",
    "end",
    "update",
    "read_access",
    "write_access"}


class LongTextField(TextField):
@@ -73,7 +79,8 @@ class JSONField(LongTextField):
    def python_value(self, value):
        if not value:
            return self.default_value
        return utils.json_loads(
            value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)


class ListField(JSONField):
@@ -81,7 +88,8 @@ class ListField(JSONField):
class SerializedField(LongTextField):
    def __init__(self, serialized_type=SerializedType.PICKLE,
                 object_hook=None, object_pairs_hook=None, **kwargs):
        self._serialized_type = serialized_type
        self._object_hook = object_hook
        self._object_pairs_hook = object_pairs_hook
@@ -95,7 +103,8 @@ class SerializedField(LongTextField):
                return None
            return utils.json_dumps(value, with_type=True)
        else:
            raise ValueError(
                f"the serialized type {self._serialized_type} is not supported")

    def python_value(self, value):
        if self._serialized_type == SerializedType.PICKLE:
@@ -103,9 +112,11 @@ class SerializedField(LongTextField):
        elif self._serialized_type == SerializedType.JSON:
            if value is None:
                return {}
            return utils.json_loads(
                value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
        else:
            raise ValueError(
                f"the serialized type {self._serialized_type} is not supported")


def is_continuous_field(cls: typing.Type) -> bool:
@@ -150,7 +161,8 @@ class BaseModel(Model):
        model_dict = self.__dict__['__data__']

        if not only_primary_with:
            return {remove_field_name_prefix(
                k): v for k, v in model_dict.items()}

        human_model_dict = {}
        for k in self._meta.primary_key.field_names:
@@ -184,17 +196,22 @@ class BaseModel(Model):
                if is_continuous_field(type(getattr(cls, attr_name))):
                    if len(f_v) == 2:
                        for i, v in enumerate(f_v):
                            if isinstance(
                                    v, str) and f_n in auto_date_timestamp_field():
                                # time type: %Y-%m-%d %H:%M:%S
                                f_v[i] = utils.date_string_to_timestamp(v)
                        lt_value = f_v[0]
                        gt_value = f_v[1]
                        if lt_value is not None and gt_value is not None:
                            filters.append(
                                cls.getter_by(attr_name).between(
                                    lt_value, gt_value))
                        elif lt_value is not None:
                            filters.append(
                                operator.attrgetter(attr_name)(cls) >= lt_value)
                        elif gt_value is not None:
                            filters.append(
                                operator.attrgetter(attr_name)(cls) <= gt_value)
                    else:
                        filters.append(operator.attrgetter(attr_name)(cls) << f_v)
                else:
@@ -205,9 +222,11 @@ class BaseModel(Model):
            if not order_by or not hasattr(cls, f"{order_by}"):
                order_by = "create_time"
            if reverse is True:
                query_records = query_records.order_by(
                    cls.getter_by(f"{order_by}").desc())
            elif reverse is False:
                query_records = query_records.order_by(
                    cls.getter_by(f"{order_by}").asc())
            return [query_record for query_record in query_records]
        else:
            return []
@@ -215,7 +234,8 @@ class BaseModel(Model):
    @classmethod
    def insert(cls, __data=None, **insert):
        if isinstance(__data, dict) and __data:
            __data[cls._meta.combined["create_time"]
                   ] = utils.current_timestamp()
        if insert:
            insert["create_time"] = utils.current_timestamp()
@@ -228,7 +248,8 @@ class BaseModel(Model):
        if not normalized:
            return {}

        normalized[cls._meta.combined["update_time"]
                   ] = utils.current_timestamp()

        for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
            if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
@@ -241,7 +262,8 @@ class BaseModel(Model):
class JsonSerializedField(SerializedField):
    def __init__(self, object_hook=utils.from_dict_hook,
                 object_pairs_hook=None, **kwargs):
        super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
                                                   object_pairs_hook=object_pairs_hook, **kwargs)
@@ -251,7 +273,8 @@ class BaseDataBase:
    def __init__(self):
        database_config = DATABASE.copy()
        db_name = database_config.pop("name")
        self.database_connection = PooledMySQLDatabase(
            db_name, **database_config)
        stat_logger.info('init mysql database on cluster mode successfully')
@@ -263,7 +286,8 @@ class DatabaseLock:
    def lock(self):
        # SQL parameters only support %s format placeholders
        cursor = self.db.execute_sql(
            "SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
        ret = cursor.fetchone()
        if ret[0] == 0:
            raise Exception(f'acquire mysql lock {self.lock_name} timeout')
@@ -273,10 +297,12 @@ class DatabaseLock:
            raise Exception(f'failed to acquire lock {self.lock_name}')

    def unlock(self):
        cursor = self.db.execute_sql(
            "SELECT RELEASE_LOCK(%s)", (self.lock_name,))
        ret = cursor.fetchone()
        if ret[0] == 0:
            raise Exception(
                f'mysql lock {self.lock_name} was not established by this thread')
        elif ret[0] == 1:
            return True
        else:
@@ -350,17 +376,37 @@ class User(DataBaseModel, UserMixin):
    access_token = CharField(max_length=255, null=True)
    nickname = CharField(max_length=100, null=False, help_text="nicky name")
    password = CharField(max_length=255, null=True, help_text="password")
    email = CharField(
        max_length=255,
        null=False,
        help_text="email",
        index=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    language = CharField(
        max_length=32,
        null=True,
        help_text="English|Chinese",
        default="Chinese")
    color_schema = CharField(
        max_length=32,
        null=True,
        help_text="Bright|Dark",
        default="Bright")
    timezone = CharField(
        max_length=64,
        null=True,
        help_text="Timezone",
        default="UTC+8\tAsia/Shanghai")
    last_login_time = DateTimeField(null=True)
    is_authenticated = CharField(max_length=1, null=False, default="1")
    is_active = CharField(max_length=1, null=False, default="1")
    is_anonymous = CharField(max_length=1, null=False, default="0")
    login_channel = CharField(null=True, help_text="from which user login")
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted1: validate)",
        default="1")
    is_superuser = BooleanField(null=True, help_text="is root", default=False)

    def __str__(self):
@@ -379,12 +425,28 @@ class Tenant(DataBaseModel):
    name = CharField(max_length=100, null=True, help_text="Tenant name")
    public_key = CharField(max_length=255, null=True)
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
    embd_id = CharField(
        max_length=128,
        null=False,
        help_text="default embedding model ID")
    asr_id = CharField(
        max_length=128,
        null=False,
        help_text="default ASR model ID")
    img2txt_id = CharField(
        max_length=128,
        null=False,
        help_text="default image to text model ID")
    parser_ids = CharField(
        max_length=256,
        null=False,
        help_text="document processors")
    credit = IntegerField(default=512)
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted1: validate)",
        default="1")

    class Meta:
        db_table = "tenant"
@@ -396,7 +458,11 @@ class UserTenant(DataBaseModel):
    tenant_id = CharField(max_length=32, null=False)
    role = CharField(max_length=32, null=False, help_text="UserTenantRole")
    invited_by = CharField(max_length=32, null=False)
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted1: validate)",
        default="1")

    class Meta:
        db_table = "user_tenant"
@@ -408,17 +474,32 @@ class InvitationCode(DataBaseModel):
    visit_time = DateTimeField(null=True)
    user_id = CharField(max_length=32, null=True)
    tenant_id = CharField(max_length=32, null=True)
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted1: validate)",
        default="1")

    class Meta:
        db_table = "invitation_code"


class LLMFactories(DataBaseModel):
    name = CharField(
        max_length=128,
        null=False,
        help_text="LLM factory name",
        primary_key=True)
    logo = TextField(null=True, help_text="llm logo base64")
    tags = CharField(
        max_length=255,
        null=False,
        help_text="LLM, Text Embedding, Image2Text, ASR")
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted1: validate)",
        default="1")

    def __str__(self):
        return self.name
@@ -429,12 +510,27 @@ class LLMFactories(DataBaseModel):
class LLM(DataBaseModel):
    # LLMs dictionary
    llm_name = CharField(
        max_length=128,
        null=False,
        help_text="LLM name",
        index=True,
        primary_key=True)
    model_type = CharField(
        max_length=128,
        null=False,
        help_text="LLM, Text Embedding, Image2Text, ASR")
    fid = CharField(max_length=128, null=False, help_text="LLM factory id")
    max_tokens = IntegerField(default=0)
    tags = CharField(
        max_length=255,
        null=False,
        help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted1: validate)",
        default="1")

    def __str__(self):
        return self.llm_name
@@ -445,9 +541,19 @@ class LLM(DataBaseModel):
class TenantLLM(DataBaseModel):
    tenant_id = CharField(max_length=32, null=False)
    llm_factory = CharField(
        max_length=128,
        null=False,
        help_text="LLM factory name")
    model_type = CharField(
        max_length=128,
        null=True,
        help_text="LLM, Text Embedding, Image2Text, ASR")
    llm_name = CharField(
        max_length=128,
        null=True,
        help_text="LLM name",
        default="")
    api_key = CharField(max_length=255, null=True, help_text="API KEY")
    api_base = CharField(max_length=255, null=True, help_text="API Base")
    used_tokens = IntegerField(default=0)
@ -464,11 +570,26 @@ class Knowledgebase(DataBaseModel):
id = CharField(max_length=32, primary_key=True) id = CharField(max_length=32, primary_key=True)
avatar = TextField(null=True, help_text="avatar base64 string") avatar = TextField(null=True, help_text="avatar base64 string")
tenant_id = CharField(max_length=32, null=False) tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=128, null=False, help_text="KB name", index=True) name = CharField(
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") max_length=128,
null=False,
help_text="KB name",
index=True)
language = CharField(
max_length=32,
null=True,
default="Chinese",
help_text="English|Chinese")
description = TextField(null=True, help_text="KB description") description = TextField(null=True, help_text="KB description")
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID") embd_id = CharField(
permission = CharField(max_length=16, null=False, help_text="me|team", default="me") max_length=128,
null=False,
help_text="default embedding model ID")
permission = CharField(
max_length=16,
null=False,
help_text="me|team",
default="me")
created_by = CharField(max_length=32, null=False) created_by = CharField(max_length=32, null=False)
doc_num = IntegerField(default=0) doc_num = IntegerField(default=0)
token_num = IntegerField(default=0) token_num = IntegerField(default=0)
@ -476,9 +597,17 @@ class Knowledgebase(DataBaseModel):
similarity_threshold = FloatField(default=0.2) similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3) vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value) parser_id = CharField(
parser_config = JSONField(null=False, default={"pages":[[1,1000000]]}) max_length=32,
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") null=False,
help_text="default parser ID",
default=ParserType.NAIVE.value)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
def __str__(self): def __str__(self):
return self.name return self.name
@ -491,22 +620,50 @@ class Document(DataBaseModel):
id = CharField(max_length=32, primary_key=True) id = CharField(max_length=32, primary_key=True)
thumbnail = TextField(null=True, help_text="thumbnail base64 string") thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True) kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID") parser_id = CharField(
parser_config = JSONField(null=False, default={"pages":[[1,1000000]]}) max_length=32,
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from") null=False,
help_text="default parser ID")
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
source_type = CharField(
max_length=128,
null=False,
default="local",
help_text="where dose this document from")
type = CharField(max_length=32, null=False, help_text="file extension") type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(max_length=32, null=False, help_text="who created it") created_by = CharField(
name = CharField(max_length=255, null=True, help_text="file name", index=True) max_length=32,
location = CharField(max_length=255, null=True, help_text="where dose it store") null=False,
help_text="who created it")
name = CharField(
max_length=255,
null=True,
help_text="file name",
index=True)
location = CharField(
max_length=255,
null=True,
help_text="where dose it store")
size = IntegerField(default=0) size = IntegerField(default=0)
token_num = IntegerField(default=0) token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0) chunk_num = IntegerField(default=0)
progress = FloatField(default=0) progress = FloatField(default=0)
progress_msg = TextField(null=True, help_text="process message", default="") progress_msg = TextField(
null=True,
help_text="process message",
default="")
process_begin_at = DateTimeField(null=True) process_begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0) process_duation = FloatField(default=0)
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0") run = CharField(
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") max_length=1,
null=True,
help_text="start to run processing or cancel.(1: run it; 2: cancel)",
default="0")
status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
class Meta: class Meta:
db_table = "document" db_table = "document"
@ -520,30 +677,52 @@ class Task(DataBaseModel):
begin_at = DateTimeField(null=True) begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0) process_duation = FloatField(default=0)
progress = FloatField(default=0) progress = FloatField(default=0)
progress_msg = TextField(null=True, help_text="process message", default="") progress_msg = TextField(
null=True,
help_text="process message",
default="")
class Dialog(DataBaseModel): class Dialog(DataBaseModel):
id = CharField(max_length=32, primary_key=True) id = CharField(max_length=32, primary_key=True)
tenant_id = CharField(max_length=32, null=False) tenant_id = CharField(max_length=32, null=False)
name = CharField(max_length=255, null=True, help_text="dialog application name") name = CharField(
max_length=255,
null=True,
help_text="dialog application name")
description = TextField(null=True, help_text="Dialog description") description = TextField(null=True, help_text="Dialog description")
icon = TextField(null=True, help_text="icon base64 string") icon = TextField(null=True, help_text="icon base64 string")
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") language = CharField(
max_length=32,
null=True,
default="Chinese",
help_text="English|Chinese")
llm_id = CharField(max_length=32, null=False, help_text="default llm ID") llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
"presence_penalty": 0.4, "max_tokens": 215}) "presence_penalty": 0.4, "max_tokens": 215})
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") prompt_type = CharField(
max_length=16,
null=False,
default="simple",
help_text="simple|advanced")
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?", prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?",
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"}) "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
similarity_threshold = FloatField(default=0.2) similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3) vector_similarity_weight = FloatField(default=0.3)
top_n = IntegerField(default=6) top_n = IntegerField(default=6)
do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1") do_refer = CharField(
max_length=1,
null=False,
help_text="it needs to insert reference index into answer or not",
default="1")
kb_ids = JSONField(null=False, default=[]) kb_ids = JSONField(null=False, default=[])
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(
max_length=1,
null=True,
help_text="is it validate(0: wasted1: validate)",
default="1")
class Meta: class Meta:
db_table = "dialog" db_table = "dialog"
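For orientation, a minimal usage sketch of the models above (an illustration, not part of this commit); it assumes the models are importable from api.db.db_models and that the DataBaseModel base exposes to_dict(), as used elsewhere in this codebase:

    # Hypothetical script querying the Dialog model defined above.
    from api.db.db_models import DB, Dialog

    with DB.connection_context():
        # Fetch all dialogs still marked valid (status == "1").
        valid_dialogs = [d.to_dict()
                         for d in Dialog.select().where(Dialog.status == "1")]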

View File

@@ -32,8 +32,7 @@ LOGGER = getLogger()
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
    DB.create_tables([model])

    for i, data in enumerate(data_source):
        current_time = current_timestamp() + i
        current_date = timestamp_to_date(current_time)
        if 'create_time' not in data:
@@ -55,7 +54,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
def get_dynamic_db_model(base, job_id):
    return type(base.model(
        table_index=get_dynamic_tracking_table_index(job_id=job_id)))


def get_dynamic_tracking_table_index(job_id):
@@ -86,7 +86,9 @@ supported_operators = {
    '~': operator.inv,
}


def query_dict2expression(
        model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
    expression = []
    for field, value in query.items():
@@ -95,7 +97,10 @@ def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[boo
        op, *val = value
        field = getattr(model, f'f_{field}')
        value = supported_operators[op](
            field, val[0]) if op in supported_operators else getattr(
            field, op)(
            *val)
        expression.append(value)
    return reduce(operator.iand, expression)

View File

@@ -61,16 +61,25 @@ def init_superuser():
    TenantService.insert(**tenant)
    UserTenantService.insert(**usr_tenant)
    TenantLLMService.insert_many(tenant_llm)
    print(
        "【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")

    chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
    msg = chat_mdl.chat(system="", history=[
                        {"role": "user", "content": "Hello!"}], gen_conf={})
    if msg.find("ERROR: ") == 0:
        print(
            "\33[91m【ERROR】\33[0m: ",
            "'{}' dosen't work. {}".format(
                tenant["llm_id"],
                msg))

    embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
    v, c = embd_mdl.encode(["Hello!"])
    if c == 0:
        print(
            "\33[91m【ERROR】\33[0m:",
            " '{}' dosen't work!".format(
                tenant["embd_id"]))


factory_infos = [{
@@ -78,28 +87,28 @@ factory_infos = [{
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
    "status": "1",
}, {
    "name": "Tongyi-Qianwen",
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
    "status": "1",
}, {
    "name": "ZHIPU-AI",
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
    "status": "1",
},
{
    "name": "Local",
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
    "status": "1",
}, {
    "name": "Moonshot",
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING",
    "status": "1",
}
    # {
    #     "name": "文心一言",
    #     "logo": "",
@@ -107,6 +116,8 @@ factory_infos = [{
    #     "status": "1",
    # },
]


def init_llm_factory():
    llm_infos = [
        # ---------------------- OpenAI ------------------------
@@ -116,37 +127,37 @@ def init_llm_factory():
            "tags": "LLM,CHAT,4K",
            "max_tokens": 4096,
            "model_type": LLMType.CHAT.value
        }, {
            "fid": factory_infos[0]["name"],
            "llm_name": "gpt-3.5-turbo-16k-0613",
            "tags": "LLM,CHAT,16k",
            "max_tokens": 16385,
            "model_type": LLMType.CHAT.value
        }, {
            "fid": factory_infos[0]["name"],
            "llm_name": "text-embedding-ada-002",
            "tags": "TEXT EMBEDDING,8K",
            "max_tokens": 8191,
            "model_type": LLMType.EMBEDDING.value
        }, {
            "fid": factory_infos[0]["name"],
            "llm_name": "whisper-1",
            "tags": "SPEECH2TEXT",
            "max_tokens": 25 * 1024 * 1024,
            "model_type": LLMType.SPEECH2TEXT.value
        }, {
            "fid": factory_infos[0]["name"],
            "llm_name": "gpt-4",
            "tags": "LLM,CHAT,8K",
            "max_tokens": 8191,
            "model_type": LLMType.CHAT.value
        }, {
            "fid": factory_infos[0]["name"],
            "llm_name": "gpt-4-32k",
            "tags": "LLM,CHAT,32K",
            "max_tokens": 32768,
            "model_type": LLMType.CHAT.value
        }, {
            "fid": factory_infos[0]["name"],
            "llm_name": "gpt-4-vision-preview",
            "tags": "LLM,CHAT,IMAGE2TEXT",
@@ -160,31 +171,31 @@ def init_llm_factory():
            "tags": "LLM,CHAT,8K",
            "max_tokens": 8191,
            "model_type": LLMType.CHAT.value
        }, {
            "fid": factory_infos[1]["name"],
            "llm_name": "qwen-plus",
            "tags": "LLM,CHAT,32K",
            "max_tokens": 32768,
            "model_type": LLMType.CHAT.value
        }, {
            "fid": factory_infos[1]["name"],
            "llm_name": "qwen-max-1201",
            "tags": "LLM,CHAT,6K",
            "max_tokens": 5899,
            "model_type": LLMType.CHAT.value
        }, {
            "fid": factory_infos[1]["name"],
            "llm_name": "text-embedding-v2",
            "tags": "TEXT EMBEDDING,2K",
            "max_tokens": 2048,
            "model_type": LLMType.EMBEDDING.value
        }, {
            "fid": factory_infos[1]["name"],
            "llm_name": "paraformer-realtime-8k-v1",
            "tags": "SPEECH2TEXT",
            "max_tokens": 25 * 1024 * 1024,
            "model_type": LLMType.SPEECH2TEXT.value
        }, {
            "fid": factory_infos[1]["name"],
            "llm_name": "qwen-vl-max",
            "tags": "LLM,CHAT,IMAGE2TEXT",
@@ -245,13 +256,13 @@ def init_llm_factory():
            "tags": "TEXT EMBEDDING,",
            "max_tokens": 128 * 1000,
            "model_type": LLMType.EMBEDDING.value
        }, {
            "fid": factory_infos[4]["name"],
            "llm_name": "moonshot-v1-32k",
            "tags": "LLM,CHAT,",
            "max_tokens": 32768,
            "model_type": LLMType.CHAT.value
        }, {
            "fid": factory_infos[4]["name"],
            "llm_name": "moonshot-v1-128k",
            "tags": "LLM,CHAT",
@@ -294,7 +305,6 @@ def init_web_data():
    print("init web data success:{}".format(time.time() - start_time))


if __name__ == '__main__':
    init_web_db()
    init_web_data()

View File

@@ -18,7 +18,8 @@ class ReloadConfigBase:
    def get_all(cls):
        configs = {}
        for k, v in cls.__dict__.items():
            if not callable(getattr(cls, k)) and not k.startswith(
                    "__") and not k.startswith("_"):
                configs[k] = v
        return configs

View File

@@ -27,7 +27,8 @@ class CommonService:
    @classmethod
    @DB.connection_context()
    def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
        return cls.model.query(cols=cols, reverse=reverse,
                               order_by=order_by, **kwargs)

    @classmethod
    @DB.connection_context()
@@ -40,9 +41,11 @@ class CommonService:
        if not order_by or not hasattr(cls, order_by):
            order_by = "create_time"
        if reverse is True:
            query_records = query_records.order_by(
                cls.model.getter_by(order_by).desc())
        elif reverse is False:
            query_records = query_records.order_by(
                cls.model.getter_by(order_by).asc())
        return query_records

    @classmethod
@@ -61,7 +64,7 @@ class CommonService:
    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        # if "id" not in kwargs:
        #     kwargs["id"] = get_uuid()
        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj
@@ -95,7 +98,8 @@ class CommonService:
            for data in data_list:
                data["update_time"] = current_timestamp()
                data["update_date"] = datetime_format(datetime.now())
                cls.model.update(data).where(
                    cls.model.id == data["id"]).execute()

    @classmethod
    @DB.connection_context()
@@ -128,7 +132,6 @@ class CommonService:
    def delete_by_id(cls, pid):
        return cls.model.delete().where(cls.model.id == pid).execute()

    @classmethod
    @DB.connection_context()
    def filter_delete(cls, filters):
@@ -151,19 +154,30 @@ class CommonService:
    @classmethod
    @DB.connection_context()
    def filter_scope_list(cls, in_key, in_filters_list,
                          filters=None, cols=None):
        in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
        if not filters:
            filters = []
        res_list = []
        if cols:
            for i in in_filters_tuple_list:
                query_records = cls.model.select(
                    *
                    cols).where(
                    getattr(
                        cls.model,
                        in_key).in_(i),
                    *
                    filters)
                if query_records:
                    res_list.extend(
                        [query_record for query_record in query_records])
        else:
            for i in in_filters_tuple_list:
                query_records = cls.model.select().where(
                    getattr(cls.model, in_key).in_(i), *filters)
                if query_records:
                    res_list.extend(
                        [query_record for query_record in query_records])
        return res_list
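As an aside (not part of the diff): filter_scope_list chunks in_filters_list into tuples of 20 before issuing IN (...) queries, so no single SQL statement carries an unbounded value list. A standalone sketch of that idea; cut_list here is a hypothetical stand-in for cls.cut_list, whose real implementation lives on the model base class and may differ:

    # Hypothetical stand-in for cls.cut_list(values, 20).
    def cut_list(values, chunk_size=20):
        return [tuple(values[i:i + chunk_size])
                for i in range(0, len(values), chunk_size)]

    # Usage sketch: one bounded IN() query per batch, results merged afterwards.
    # for batch in cut_list(doc_ids, 20):
    #     res_list.extend(Document.select().where(Document.id.in_(batch)))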

View File

@@ -21,6 +21,5 @@ class DialogService(CommonService):
    model = Dialog


class ConversationService(CommonService):
    model = Conversation

View File

@@ -72,7 +72,20 @@ class DocumentService(CommonService):
    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time]
        docs = cls.model.select(*fields) \
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
@@ -104,39 +117,63 @@ class DocumentService(CommonService):
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        num = cls.model.update(token_num=cls.model.token_num + token_num,
                               chunk_num=cls.model.chunk_num + chunk_num,
                               process_duation=cls.model.process_duation + duation).where(
            cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")
        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num +
            token_num,
            chunk_num=Knowledgebase.chunk_num +
            chunk_num).where(
            Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        docs = cls.model.select(
            Knowledgebase.tenant_id).join(
            Knowledgebase, on=(
                Knowledgebase.id == cls.model.kb_id)).where(
            cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        fields = [cls.model.id, cls.model.thumbnail]
        return list(cls.model.select(
            *fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v
        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        docs = cls.model.select(cls.model.id).join(Knowledgebase,
                                                   on=(Knowledgebase.id == cls.model.kb_id)).where(
            Knowledgebase.tenant_id == tenant_id)
        return len(docs)
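For clarity, dfs_update above is a recursive in-place merge: keys in the new config overwrite or extend the stored parser_config without discarding sibling keys. A standalone illustration of the same logic on made-up config keys (not part of the commit):

    # Same merge logic as in update_parser_config, applied to example dicts.
    def dfs_update(old, new):
        for k, v in new.items():
            if k not in old:
                old[k] = v
                continue
            if isinstance(v, dict):
                assert isinstance(old[k], dict)
                dfs_update(old[k], v)
            else:
                old[k] = v

    stored = {"pages": [[1, 1000000]], "layout": {"ocr": True, "lang": "en"}}
    dfs_update(stored, {"layout": {"lang": "zh"}})
    # stored is now {"pages": [[1, 1000000]], "layout": {"ocr": True, "lang": "zh"}}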

View File

@@ -55,7 +55,7 @@ class KnowledgebaseService(CommonService):
            cls.model.chunk_num,
            cls.model.parser_id,
            cls.model.parser_config]
        kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
            (cls.model.id == kb_id),
            (cls.model.status == StatusEnum.VALID.value)
        )
@@ -69,9 +69,11 @@ class KnowledgebaseService(CommonService):
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        e, m = cls.get_by_id(id)
        if not e:
            raise LookupError(f"knowledgebase({id}) not found.")

        def dfs_update(old, new):
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
@@ -80,12 +82,12 @@ class KnowledgebaseService(CommonService):
                    dfs_update(old[k], v)
                elif isinstance(v, list):
                    assert isinstance(old[k], list)
                    old[k] = list(set(old[k] + v))
                else:
                    old[k] = v
        dfs_update(m.parser_config, config)
        cls.update_by_id(id, {"parser_config": m.parser_config})

    @classmethod
    @DB.connection_context()
    def get_field_map(cls, ids):
@@ -94,4 +96,3 @@ class KnowledgebaseService(CommonService):
            if k.parser_config and "field_map" in k.parser_config:
                conf.update(k.parser_config["field_map"])
        return conf

View File

@@ -59,7 +59,8 @@ class TenantLLMService(CommonService):
    @classmethod
    @DB.connection_context()
    def model_instance(cls, tenant_id, llm_type,
                       llm_name=None, lang="Chinese"):
        e, tenant = TenantService.get_by_id(tenant_id)
        if not e:
            raise LookupError("Tenant not found")
@@ -126,29 +127,39 @@ class LLMBundle(object):
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(
            tenant_id, llm_type, llm_name, lang=lang)
        assert self.mdl, "Can't find mole for {}/{}/{}".format(
            tenant_id, llm_type, llm_name)

    def encode(self, texts: list, batch_size=32):
        emd, used_tokens = self.mdl.encode(texts, batch_size)
        if TenantLLMService.increase_usage(
                self.tenant_id, self.llm_type, used_tokens):
            database_logger.error(
                "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
        return emd, used_tokens

    def encode_queries(self, query: str):
        emd, used_tokens = self.mdl.encode_queries(query)
        if TenantLLMService.increase_usage(
                self.tenant_id, self.llm_type, used_tokens):
            database_logger.error(
                "Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
        return emd, used_tokens

    def describe(self, image, max_tokens=300):
        txt, used_tokens = self.mdl.describe(image, max_tokens)
        if not TenantLLMService.increase_usage(
                self.tenant_id, self.llm_type, used_tokens):
            database_logger.error(
                "Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
        return txt

    def chat(self, system, history, gen_conf):
        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
        if TenantLLMService.increase_usage(
                self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            database_logger.error(
                "Can't update token usage for {}/CHAT".format(self.tenant_id))
        return txt
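For context, a minimal usage sketch of LLMBundle (an illustration, not part of the commit; the tenant id is a placeholder, the model name is taken from the llm_infos list above, and the import path assumes LLMBundle lives in api/db/services/llm_service.py):

    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle

    # Wraps the tenant's chat model and records token usage on every call.
    chat_mdl = LLMBundle("some-tenant-id", LLMType.CHAT, "gpt-3.5-turbo")
    answer = chat_mdl.chat(system="",
                           history=[{"role": "user", "content": "Hello!"}],
                           gen_conf={"temperature": 0.1})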

View File

@@ -54,7 +54,8 @@ class UserService(CommonService):
        if "id" not in kwargs:
            kwargs["id"] = get_uuid()
        if "password" in kwargs:
            kwargs["password"] = generate_password_hash(
                str(kwargs["password"]))

        kwargs["create_time"] = current_timestamp()
        kwargs["create_date"] = datetime_format(datetime.now())
@@ -63,12 +64,12 @@ class UserService(CommonService):
        obj = cls.model(**kwargs).save(force_insert=True)
        return obj

    @classmethod
    @DB.connection_context()
    def delete_user(cls, user_ids, update_user_dict):
        with DB.atomic():
            cls.model.update({"status": 0}).where(
                cls.model.id.in_(user_ids)).execute()

    @classmethod
    @DB.connection_context()
@@ -77,7 +78,8 @@ class UserService(CommonService):
        if user_dict:
            user_dict["update_time"] = current_timestamp()
            user_dict["update_date"] = datetime_format(datetime.now())
            cls.model.update(user_dict).where(
                cls.model.id == user_id).execute()


class TenantService(CommonService):
@@ -86,17 +88,32 @@ class TenantService(CommonService):
    @classmethod
    @DB.connection_context()
    def get_by_user_id(cls, user_id):
        fields = [
            cls.model.id.alias("tenant_id"),
            cls.model.name,
            cls.model.llm_id,
            cls.model.embd_id,
            cls.model.asr_id,
            cls.model.img2txt_id,
            cls.model.parser_ids,
            UserTenant.role]
        return list(cls.model.select(*fields)
                    .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
                    .where(cls.model.status == StatusEnum.VALID.value).dicts())

    @classmethod
    @DB.connection_context()
    def get_joined_tenants_by_user_id(cls, user_id):
        fields = [
            cls.model.id.alias("tenant_id"),
            cls.model.name,
            cls.model.llm_id,
            cls.model.embd_id,
            cls.model.asr_id,
            cls.model.img2txt_id,
            UserTenant.role]
        return list(cls.model.select(*fields)
                    .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL.value)))
                    .where(cls.model.status == StatusEnum.VALID.value).dicts())

    @classmethod
@@ -104,7 +121,9 @@ class TenantService(CommonService):
    def decrease(cls, user_id, num):
        num = cls.model.update(credit=cls.model.credit - num).where(
            cls.model.id == user_id).execute()
        if num == 0:
            raise LookupError("Tenant not found which is supposed to be there")


class UserTenantService(CommonService):
    model = UserTenant

View File

@@ -13,16 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from rag.utils import ELASTICSEARCH
from rag.nlp import search
import os
from enum import IntEnum, Enum
from api.utils import get_base_config, decrypt_database_config
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger

# Logger
LoggerFactory.set_directory(
    os.path.join(
        get_project_base_directory(),
        "logs",
        "api"))
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 10
@@ -86,7 +92,9 @@ default_llm = {
LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
if LLM_FACTORY not in default_llm:
    print(
        "\33[91m【ERROR】\33[0m:",
        f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.")
    LLM_FACTORY = "Tongyi-Qianwen"
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
@@ -94,7 +102,9 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
API_KEY = LLM.get("api_key", "")
PARSERS = LLM.get(
    "parsers",
    "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One")

# distribution
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
@@ -103,13 +113,25 @@ RAG_FLOW_UPDATE_CHECK = False
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")

SECRET_KEY = get_base_config(
    RAG_FLOW_SERVICE_NAME,
    {}).get(
    "secret_key",
    "infiniflow")
TOKEN_EXPIRE_IN = get_base_config(
    RAG_FLOW_SERVICE_NAME, {}).get(
    "token_expires_in", 3600)

NGINX_HOST = get_base_config(
    RAG_FLOW_SERVICE_NAME, {}).get(
    "nginx", {}).get("host") or HOST
NGINX_HTTP_PORT = get_base_config(
    RAG_FLOW_SERVICE_NAME, {}).get(
    "nginx", {}).get("http_port") or HTTP_PORT

RANDOM_INSTANCE_ID = get_base_config(
    RAG_FLOW_SERVICE_NAME, {}).get(
    "random_instance_id", False)

PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
@@ -124,7 +146,9 @@ UPLOAD_DATA_FROM_CLIENT = True
AUTHENTICATION_CONF = get_base_config("authentication", {})

# client
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
    "client", {}).get(
    "switch", False)
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat")
@@ -151,8 +175,6 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s
PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False

retrievaler = search.Dealer(ELASTICSEARCH)
@@ -162,7 +184,7 @@ class CustomEnum(Enum):
        try:
            cls(value)
            return True
        except BaseException:
            return False

    @classmethod

View File

@@ -34,10 +34,12 @@ from . import file_utils
SERVICE_CONF = "service_conf.yaml"


def conf_realpath(conf_name):
    conf_path = f"conf/{conf_name}"
    return os.path.join(file_utils.get_project_base_directory(), conf_path)


def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
    local_config = {}
    local_path = conf_realpath(f'local.{conf_name}')
@@ -62,7 +64,8 @@ def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict:
    return config.get(key, default) if key is not None else config


use_deserialize_safe_module = get_base_config(
    'use_deserialize_safe_module', False)


class CoordinationCommunicationProtocol(object):
@@ -93,7 +96,8 @@ class BaseType:
                    data[_k] = _dict(vv)
            else:
                data = obj
            return {"type": obj.__class__.__name__,
                    "data": data, "module": module}
        return _dict(self)
@@ -129,7 +133,8 @@ def rag_uuid():
def string_to_bytes(string):
    return string if isinstance(
        string, bytes) else string.encode(encoding="utf-8")


def bytes_to_string(byte):
@@ -137,7 +142,11 @@ def bytes_to_string(byte):
def json_dumps(src, byte=False, indent=None, with_type=False):
    dest = json.dumps(
        src,
        indent=indent,
        cls=CustomJSONEncoder,
        with_type=with_type)
    if byte:
        dest = string_to_bytes(dest)
    return dest
@@ -146,7 +155,8 @@ def json_dumps(src, byte=False, indent=None, with_type=False):
def json_loads(src, object_hook=None, object_pairs_hook=None):
    if isinstance(src, bytes):
        src = bytes_to_string(src)
    return json.loads(src, object_hook=object_hook,
                      object_pairs_hook=object_pairs_hook)


def current_timestamp():
@@ -177,7 +187,9 @@ def serialize_b64(src, to_str=False):
def deserialize_b64(src):
    src = base64.b64decode(
        string_to_bytes(src) if isinstance(
            src, str) else src)
    if use_deserialize_safe_module:
        return restricted_loads(src)
    return pickle.loads(src)
@@ -237,12 +249,14 @@ def get_lan_ip():
        pass
    return ip or ''


def from_dict_hook(in_dict: dict):
    if "type" in in_dict and "data" in in_dict:
        if in_dict["module"] is None:
            return in_dict["data"]
        else:
            return getattr(importlib.import_module(
                in_dict["module"]), in_dict["type"])(**in_dict["data"])
    else:
        return in_dict
@@ -259,12 +273,16 @@ def decrypt_database_password(password):
        raise ValueError("No private key")

    module_fun = encrypt_module.split("#")
    pwdecrypt_fun = getattr(
        importlib.import_module(
            module_fun[0]),
        module_fun[1])

    return pwdecrypt_fun(private_key, password)


def decrypt_database_config(
        database=None, passwd_key="password", name="database"):
    if not database:
        database = get_base_config(name, {})
@@ -275,7 +293,8 @@ def decrypt_database_config(database=None, passwd_key="password", name="database
def update_config(key, value, conf_name=SERVICE_CONF):
    conf_path = conf_realpath(conf_name=conf_name)
    if not os.path.isabs(conf_path):
        conf_path = os.path.join(
            file_utils.get_project_base_directory(), conf_path)

    with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
@@ -288,7 +307,8 @@ def get_uuid():
def datetime_format(date_time: datetime.datetime) -> datetime.datetime:
    return datetime.datetime(date_time.year, date_time.month, date_time.day,
                             date_time.hour, date_time.minute, date_time.second)


def get_format_time() -> datetime.datetime:
@@ -307,14 +327,19 @@ def elapsed2time(elapsed):
def decrypt(line):
    file_path = os.path.join(
        file_utils.get_project_base_directory(),
        "conf",
        "private.pem")
    rsa_key = RSA.importKey(open(file_path).read(), "Welcome")
    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
    return cipher.decrypt(base64.b64decode(
        line), "Fail to decrypt password!").decode('utf-8')


def download_img(url):
    if not url:
        return ""
    response = requests.get(url)
    return "data:" + \
        response.headers.get('Content-Type', 'image/jpg') + ";" + \

View File

@@ -19,7 +19,7 @@ import time
from functools import wraps
from io import BytesIO
from flask import (
    Response, jsonify, send_file, make_response,
    request as flask_request,
)
from werkzeug.http import HTTP_STATUS_CODES
@@ -29,7 +29,7 @@ from api.versions import get_rag_version
from api.settings import RetCode
from api.settings import (
    REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
    stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
import requests
import functools
@@ -40,14 +40,21 @@ from hmac import HMAC
from urllib.parse import quote, urlencode

requests.models.complexjson.dumps = functools.partial(
    json.dumps, cls=CustomJSONEncoder)


def request(**kwargs):
    sess = requests.Session()
    stream = kwargs.pop('stream', sess.stream)
    timeout = kwargs.pop('timeout', None)
    kwargs['headers'] = {
        k.replace(
            '_',
            '-').upper(): v for k,
        v in kwargs.get(
            'headers',
            {}).items()}
    prepped = requests.Request(**kwargs).prepare()

    if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
@@ -59,7 +66,11 @@ def request(**kwargs):
            HTTP_APP_KEY.encode('ascii'),
            prepped.path_url.encode('ascii'),
            prepped.body if kwargs.get('json') else b'',
            urlencode(
                sorted(
                    kwargs['data'].items()),
                quote_via=quote,
                safe='-._~').encode('ascii')
            if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'',
        ]), 'sha1').digest()).decode('ascii')
@@ -88,11 +99,12 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
    return max(0, countdown)


def get_json_result(retcode=RetCode.SUCCESS, retmsg='success',
                    data=None, job_id=None, meta=None):
    import re
    result_dict = {
        "retcode": retcode,
        "retmsg": retmsg,
        # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE),
        "data": data,
        "jobId": job_id,
@@ -107,9 +119,17 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id
            response[key] = value
    return jsonify(response)


def get_data_error_result(retcode=RetCode.DATA_ERROR,
                          retmsg='Sorry! Data missing!'):
    import re
    result_dict = {
        "retcode": retcode,
        "retmsg": re.sub(
            r"rag",
            "seceum",
            retmsg,
            flags=re.IGNORECASE)}
    response = {}
    for key, value in result_dict.items():
        if value is None and key != "retcode":
@@ -118,15 +138,17 @@ def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missin
            response[key] = value
    return jsonify(response)


def server_error_response(e):
    stat_logger.exception(e)
    try:
        if e.code == 401:
            return get_json_result(retcode=401, retmsg=repr(e))
    except BaseException:
        pass
    if len(e.args) > 1:
        return get_json_result(
            retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
    return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
@@ -162,10 +184,13 @@ def validate_request(*args, **kwargs):
            if no_arguments or error_arguments:
                error_string = ""
                if no_arguments:
                    error_string += "required argument are missing: {}; ".format(
                        ",".join(no_arguments))
                if error_arguments:
                    error_string += "required argument values: {}".format(
                        ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
                return get_json_result(
                    retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
            return func(*_args, **_kwargs)
        return decorated_function
    return wrapper
@@ -193,7 +218,8 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None):
    return jsonify(response)


def cors_reponse(retcode=RetCode.SUCCESS,
                 retmsg='success', data=None, auth=None):
    result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data}
    response_dict = {}
    for key, value in result_dict.items():

View File

@@ -29,6 +29,7 @@ from api.db import FileType
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
RAG_BASE = os.getenv("RAG_BASE")


def get_project_base_directory(*args):
    global PROJECT_BASE
    if PROJECT_BASE is None:
@@ -65,7 +66,6 @@ def get_rag_python_directory(*args):
    return get_rag_directory("python", *args)


@cached(cache=LRUCache(maxsize=10))
def load_json_conf(conf_path):
    if os.path.isabs(conf_path):
@@ -146,10 +146,12 @@ def filename_type(filename):
    if re.match(r".*\.pdf$", filename):
        return FileType.PDF.value

    if re.match(
            r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename):
        return FileType.DOC.value

    if re.match(
            r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
        return FileType.AURAL.value

    if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
@@ -164,14 +166,16 @@ def thumbnail(filename, blob):
        buffered = BytesIO()
        Image.frombytes("RGB", [pix.width, pix.height],
                        pix.samples).save(buffered, format="png")
        return "data:image/png;base64," + \
            base64.b64encode(buffered.getvalue()).decode("utf-8")

    if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
        image = Image.open(BytesIO(blob))
        image.thumbnail((30, 30))
        buffered = BytesIO()
        image.save(buffered, format="png")
        return "data:image/png;base64," + \
            base64.b64encode(buffered.getvalue()).decode("utf-8")

    if re.match(r".*\.(ppt|pptx)$", filename):
        import aspose.slides as slides
@@ -179,8 +183,10 @@ def thumbnail(filename, blob):
        try:
            with slides.Presentation(BytesIO(blob)) as presentation:
                buffered = BytesIO()
                presentation.slides[0].get_thumbnail(0.03, 0.03).save(
                    buffered, drawing.imaging.ImageFormat.png)
                return "data:image/png;base64," + \
                    base64.b64encode(buffered.getvalue()).decode("utf-8")
        except Exception as e:
            pass
@@ -190,6 +196,3 @@ def traversal_files(base):
        for f in fs:
            fullname = os.path.join(root, f)
            yield fullname
@ -23,6 +23,7 @@ from threading import RLock
from api.utils import file_utils from api.utils import file_utils
class LoggerFactory(object): class LoggerFactory(object):
TYPE = "FILE" TYPE = "FILE"
LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s" LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
@ -49,7 +50,8 @@ class LoggerFactory(object):
schedule_logger_dict = {} schedule_logger_dict = {}
@staticmethod @staticmethod
def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False): def set_directory(directory=None, parent_log_dir=None,
append_to_parent_log=None, force=False):
if parent_log_dir: if parent_log_dir:
LoggerFactory.PARENT_LOG_DIR = parent_log_dir LoggerFactory.PARENT_LOG_DIR = parent_log_dir
if append_to_parent_log: if append_to_parent_log:
@ -66,11 +68,13 @@ class LoggerFactory(object):
else: else:
os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True) os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
for loggerName, ghandler in LoggerFactory.global_handler_dict.items(): for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
for className, (logger, handler) in LoggerFactory.logger_dict.items(): for className, (logger,
handler) in LoggerFactory.logger_dict.items():
logger.removeHandler(ghandler) logger.removeHandler(ghandler)
ghandler.close() ghandler.close()
LoggerFactory.global_handler_dict = {} LoggerFactory.global_handler_dict = {}
for className, (logger, handler) in LoggerFactory.logger_dict.items(): for className, (logger,
handler) in LoggerFactory.logger_dict.items():
logger.removeHandler(handler) logger.removeHandler(handler)
_handler = None _handler = None
if handler: if handler:
@ -111,19 +115,23 @@ class LoggerFactory(object):
if logger_name_key not in LoggerFactory.global_handler_dict: if logger_name_key not in LoggerFactory.global_handler_dict:
with LoggerFactory.lock: with LoggerFactory.lock:
if logger_name_key not in LoggerFactory.global_handler_dict: if logger_name_key not in LoggerFactory.global_handler_dict:
handler = LoggerFactory.get_handler(logger_name, level, log_dir) handler = LoggerFactory.get_handler(
logger_name, level, log_dir)
LoggerFactory.global_handler_dict[logger_name_key] = handler LoggerFactory.global_handler_dict[logger_name_key] = handler
return LoggerFactory.global_handler_dict[logger_name_key] return LoggerFactory.global_handler_dict[logger_name_key]
@staticmethod @staticmethod
def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None): def get_handler(class_name, level=None, log_dir=None,
log_type=None, job_id=None):
if not log_type: if not log_type:
if not LoggerFactory.LOG_DIR or not class_name: if not LoggerFactory.LOG_DIR or not class_name:
return logging.StreamHandler() return logging.StreamHandler()
# return Diy_StreamHandler() # return Diy_StreamHandler()
if not log_dir: if not log_dir:
log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name)) log_file = os.path.join(
LoggerFactory.LOG_DIR,
"{}.log".format(class_name))
else: else:
log_file = os.path.join(log_dir, "{}.log".format(class_name)) log_file = os.path.join(log_dir, "{}.log".format(class_name))
else: else:
@ -170,7 +178,9 @@ class LoggerFactory(object):
for level in LoggerFactory.levels: for level in LoggerFactory.levels:
if level >= LoggerFactory.LEVEL: if level >= LoggerFactory.LEVEL:
level_logger_name = logging._levelToName[level] level_logger_name = logging._levelToName[level]
logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level)) logger.addHandler(
LoggerFactory.get_global_handler(
level_logger_name, level))
if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR: if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
for level in LoggerFactory.levels: for level in LoggerFactory.levels:
if level >= LoggerFactory.LEVEL: if level >= LoggerFactory.LEVEL:
@ -224,22 +234,26 @@ def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
return f"{prefix}start to {msg}{suffix}" return f"{prefix}start to {msg}{suffix}"
def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None): def successful_log(msg, job=None, task=None, role=None,
party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail) prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} successfully{suffix}" return f"{prefix}{msg} successfully{suffix}"
def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None): def warning_log(msg, job=None, task=None, role=None,
party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail) prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} is not effective{suffix}" return f"{prefix}{msg} is not effective{suffix}"
def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None): def failed_log(msg, job=None, task=None, role=None,
party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail) prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}failed to {msg}{suffix}" return f"{prefix}failed to {msg}{suffix}"
def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None): def base_msg(job=None, task=None, role: str = None,
party_id: typing.Union[str, int] = None, detail=None):
if detail: if detail:
detail_msg = f" detail: \n{detail}" detail_msg = f" detail: \n{detail}"
else: else:
@ -285,10 +299,14 @@ def get_job_logger(job_id, log_type):
for job_log_dir in log_dirs: for job_log_dir in log_dirs:
handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL, handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
log_dir=job_log_dir, log_type=log_type, job_id=job_id) log_dir=job_log_dir, log_type=log_type, job_id=job_id)
error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id) error_handler = LoggerFactory.get_handler(
class_name=None,
level=logging.ERROR,
log_dir=job_log_dir,
log_type=log_type,
job_id=job_id)
logger.addHandler(handler) logger.addHandler(handler)
logger.addHandler(error_handler) logger.addHandler(error_handler)
with LoggerFactory.lock: with LoggerFactory.lock:
LoggerFactory.schedule_logger_dict[job_id + log_type] = logger LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
return logger return logger
@ -1,18 +1,23 @@
import base64, os, sys import base64
import os
import sys
from Cryptodome.PublicKey import RSA from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from api.utils import decrypt, file_utils from api.utils import decrypt, file_utils
def crypt(line): def crypt(line):
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem") file_path = os.path.join(
file_utils.get_project_base_directory(),
"conf",
"public.pem")
rsa_key = RSA.importKey(open(file_path).read()) rsa_key = RSA.importKey(open(file_path).read())
cipher = Cipher_pkcs1_v1_5.new(rsa_key) cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8") return base64.b64encode(cipher.encrypt(
line.encode('utf-8'))).decode("utf-8")
if __name__ == "__main__": if __name__ == "__main__":
pswd = crypt(sys.argv[1]) pswd = crypt(sys.argv[1])
print(pswd) print(pswd)
print(decrypt(pswd)) print(decrypt(pswd))
@ -4,5 +4,3 @@ from .pdf_parser import HuParser as PdfParser, PlainParser
from .docx_parser import HuDocxParser as DocxParser from .docx_parser import HuDocxParser as DocxParser
from .excel_parser import HuExcelParser as ExcelParser from .excel_parser import HuExcelParser as ExcelParser
from .ppt_parser import HuPptParser as PptParser from .ppt_parser import HuPptParser as PptParser
@ -99,12 +99,15 @@ class HuDocxParser:
return ["\n".join(lines)] return ["\n".join(lines)]
def __call__(self, fnm, from_page=0, to_page=100000): def __call__(self, fnm, from_page=0, to_page=100000):
self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm)) self.doc = Document(fnm) if isinstance(
fnm, str) else Document(BytesIO(fnm))
pn = 0 pn = 0
secs = [] secs = []
for p in self.doc.paragraphs: for p in self.doc.paragraphs:
if pn > to_page: break if pn > to_page:
if from_page <= pn < to_page and p.text.strip(): secs.append((p.text, p.style.name)) break
if from_page <= pn < to_page and p.text.strip():
secs.append((p.text, p.style.name))
for run in p.runs: for run in p.runs:
if 'lastRenderedPageBreak' in run._element.xml: if 'lastRenderedPageBreak' in run._element.xml:
pn += 1 pn += 1
@ -15,13 +15,16 @@ class HuExcelParser:
ws = wb[sheetname] ws = wb[sheetname]
rows = list(ws.rows) rows = list(ws.rows)
tb += f"<table><caption>{sheetname}</caption><tr>" tb += f"<table><caption>{sheetname}</caption><tr>"
for t in list(rows[0]): tb += f"<th>{t.value}</th>" for t in list(rows[0]):
tb += f"<th>{t.value}</th>"
tb += "</tr>" tb += "</tr>"
for r in list(rows[1:]): for r in list(rows[1:]):
tb += "<tr>" tb += "<tr>"
for i,c in enumerate(r): for i, c in enumerate(r):
if c.value is None: tb += "<td></td>" if c.value is None:
else: tb += f"<td>{c.value}</td>" tb += "<td></td>"
else:
tb += f"<td>{c.value}</td>"
tb += "</tr>" tb += "</tr>"
tb += "</table>\n" tb += "</table>\n"
return tb return tb
@ -38,13 +41,15 @@ class HuExcelParser:
ti = list(rows[0]) ti = list(rows[0])
for r in list(rows[1:]): for r in list(rows[1:]):
l = [] l = []
for i,c in enumerate(r): for i, c in enumerate(r):
if not c.value:continue if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else "" t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value) t += ("" if t else "") + str(c.value)
l.append(t) l.append(t)
l = "; ".join(l) l = "; ".join(l)
if sheetname.lower().find("sheet") <0: l += " ——"+sheetname if sheetname.lower().find("sheet") < 0:
l += " ——" + sheetname
res.append(l) res.append(l)
return res return res
@ -43,9 +43,11 @@ class HuParser:
"rag/res/deepdoc"), "rag/res/deepdoc"),
local_files_only=True) local_files_only=True)
except Exception as e: except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0") model_dir = snapshot_download(
repo_id="InfiniFlow/text_concat_xgb_v1.0")
self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model")) self.updown_cnt_mdl.load_model(os.path.join(
model_dir, "updown_concat_xgb.model"))
self.page_from = 0 self.page_from = 0
""" """
If you have trouble downloading HuggingFace models, -_^ this might help!! If you have trouble downloading HuggingFace models, -_^ this might help!!
@ -235,7 +237,8 @@ class HuParser:
b["R_top"] = rows[ii]["top"] b["R_top"] = rows[ii]["top"]
b["R_bott"] = rows[ii]["bottom"] b["R_bott"] = rows[ii]["bottom"]
ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3) ii = Recognizer.find_overlapped_with_threashold(
b, headers, thr=0.3)
if ii is not None: if ii is not None:
b["H_top"] = headers[ii]["top"] b["H_top"] = headers[ii]["top"]
b["H_bott"] = headers[ii]["bottom"] b["H_bott"] = headers[ii]["bottom"]
@ -272,7 +275,8 @@ class HuParser:
) )
# merge chars in the same rect # merge chars in the same rect
for c in Recognizer.sort_X_firstly(chars, self.mean_width[pagenum - 1] // 4): for c in Recognizer.sort_X_firstly(
chars, self.mean_width[pagenum - 1] // 4):
ii = Recognizer.find_overlapped(c, bxs) ii = Recognizer.find_overlapped(c, bxs)
if ii is None: if ii is None:
self.lefted_chars.append(c) self.lefted_chars.append(c)
@ -283,13 +287,15 @@ class HuParser:
self.lefted_chars.append(c) self.lefted_chars.append(c)
continue continue
if c["text"] == " " and bxs[ii]["text"]: if c["text"] == " " and bxs[ii]["text"]:
if re.match(r"[0-9a-zA-Z,.?;:!%%]", bxs[ii]["text"][-1]): bxs[ii]["text"] += " " if re.match(r"[0-9a-zA-Z,.?;:!%%]", bxs[ii]["text"][-1]):
bxs[ii]["text"] += " "
else: else:
bxs[ii]["text"] += c["text"] bxs[ii]["text"] += c["text"]
for b in bxs: for b in bxs:
if not b["text"]: if not b["text"]:
left, right, top, bott = b["x0"] * ZM, b["x1"] * ZM, b["top"] * ZM, b["bottom"] * ZM left, right, top, bott = b["x0"] * ZM, b["x1"] * \
ZM, b["top"] * ZM, b["bottom"] * ZM
b["text"] = self.ocr.recognize(np.array(img), b["text"] = self.ocr.recognize(np.array(img),
np.array([[left, top], [right, top], [right, bott], [left, bott]], np.array([[left, top], [right, top], [right, bott], [left, bott]],
dtype=np.float32)) dtype=np.float32))
@ -302,7 +308,8 @@ class HuParser:
def _layouts_rec(self, ZM, drop=True): def _layouts_rec(self, ZM, drop=True):
assert len(self.page_images) == len(self.boxes) assert len(self.page_images) == len(self.boxes)
self.boxes, self.page_layout = self.layouter(self.page_images, self.boxes, ZM, drop=drop) self.boxes, self.page_layout = self.layouter(
self.page_images, self.boxes, ZM, drop=drop)
# cumlative Y # cumlative Y
for i in range(len(self.boxes)): for i in range(len(self.boxes)):
self.boxes[i]["top"] += \ self.boxes[i]["top"] += \
@ -332,7 +339,8 @@ class HuParser:
"equation"]: "equation"]:
i += 1 i += 1
continue continue
if abs(self._y_dis(b, b_)) < self.mean_height[bxs[i]["page_number"] - 1] / 3: if abs(self._y_dis(b, b_)
) < self.mean_height[bxs[i]["page_number"] - 1] / 3:
# merge # merge
bxs[i]["x1"] = b_["x1"] bxs[i]["x1"] = b_["x1"]
bxs[i]["top"] = (b["top"] + b_["top"]) / 2 bxs[i]["top"] = (b["top"] + b_["top"]) / 2
@ -366,12 +374,15 @@ class HuParser:
self.boxes = bxs self.boxes = bxs
def _naive_vertical_merge(self): def _naive_vertical_merge(self):
bxs = Recognizer.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3) bxs = Recognizer.sort_Y_firstly(
self.boxes, np.median(
self.mean_height) / 3)
i = 0 i = 0
while i + 1 < len(bxs): while i + 1 < len(bxs):
b = bxs[i] b = bxs[i]
b_ = bxs[i + 1] b_ = bxs[i + 1]
if b["page_number"] < b_["page_number"] and re.match(r"[0-9 •一—-]+$", b["text"]): if b["page_number"] < b_["page_number"] and re.match(
r"[0-9 •一—-]+$", b["text"]):
bxs.pop(i) bxs.pop(i)
continue continue
if not b["text"].strip(): if not b["text"].strip():
@ -379,7 +390,8 @@ class HuParser:
continue continue
concatting_feats = [ concatting_feats = [
b["text"].strip()[-1] in ",;:'\",、‘“;:-", b["text"].strip()[-1] in ",;:'\",、‘“;:-",
len(b["text"].strip()) > 1 and b["text"].strip()[-2] in ",;:'\",‘“、;:", len(b["text"].strip()) > 1 and b["text"].strip(
)[-2] in ",;:'\",‘“、;:",
b["text"].strip()[0] in "。;?!?”)),,、:", b["text"].strip()[0] in "。;?!?”)),,、:",
] ]
# features for not concating # features for not concating
@ -387,7 +399,7 @@ class HuParser:
b.get("layoutno", 0) != b.get("layoutno", 0), b.get("layoutno", 0) != b.get("layoutno", 0),
b["text"].strip()[-1] in "。?!?", b["text"].strip()[-1] in "。?!?",
self.is_english and b["text"].strip()[-1] in ".!?", self.is_english and b["text"].strip()[-1] in ".!?",
b["page_number"] == b_["page_number"] and b_["top"] - \ b["page_number"] == b_["page_number"] and b_["top"] -
b["bottom"] > self.mean_height[b["page_number"] - 1] * 1.5, b["bottom"] > self.mean_height[b["page_number"] - 1] * 1.5,
b["page_number"] < b_["page_number"] and abs( b["page_number"] < b_["page_number"] and abs(
b["x0"] - b_["x0"]) > self.mean_width[b["page_number"] - 1] * 4, b["x0"] - b_["x0"]) > self.mean_width[b["page_number"] - 1] * 4,
@ -396,7 +408,12 @@ class HuParser:
detach_feats = [b["x1"] < b_["x0"], detach_feats = [b["x1"] < b_["x0"],
b["x0"] > b_["x1"]] b["x0"] > b_["x1"]]
if (any(feats) and not any(concatting_feats)) or any(detach_feats): if (any(feats) and not any(concatting_feats)) or any(detach_feats):
print(b["text"], b_["text"], any(feats), any(concatting_feats), any(detach_feats)) print(
b["text"],
b_["text"],
any(feats),
any(concatting_feats),
any(detach_feats))
i += 1 i += 1
continue continue
# merge up and down # merge up and down
@ -526,31 +543,39 @@ class HuParser:
i += 1 i += 1
continue continue
findit = True findit = True
eng = re.match(r"[0-9a-zA-Z :'.-]{5,}", self.boxes[i]["text"].strip()) eng = re.match(
r"[0-9a-zA-Z :'.-]{5,}",
self.boxes[i]["text"].strip())
self.boxes.pop(i) self.boxes.pop(i)
if i >= len(self.boxes): break if i >= len(self.boxes):
break
prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join( prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
self.boxes[i]["text"].strip().split(" ")[:2]) self.boxes[i]["text"].strip().split(" ")[:2])
while not prefix: while not prefix:
self.boxes.pop(i) self.boxes.pop(i)
if i >= len(self.boxes): break if i >= len(self.boxes):
break
prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join( prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join(
self.boxes[i]["text"].strip().split(" ")[:2]) self.boxes[i]["text"].strip().split(" ")[:2])
self.boxes.pop(i) self.boxes.pop(i)
if i >= len(self.boxes) or not prefix: break if i >= len(self.boxes) or not prefix:
break
for j in range(i, min(i + 128, len(self.boxes))): for j in range(i, min(i + 128, len(self.boxes))):
if not re.match(prefix, self.boxes[j]["text"]): if not re.match(prefix, self.boxes[j]["text"]):
continue continue
for k in range(i, j): self.boxes.pop(i) for k in range(i, j):
self.boxes.pop(i)
break break
if findit: return if findit:
return
page_dirty = [0] * len(self.page_images) page_dirty = [0] * len(self.page_images)
for b in self.boxes: for b in self.boxes:
if re.search(r"(··|··|··)", b["text"]): if re.search(r"(··|··|··)", b["text"]):
page_dirty[b["page_number"] - 1] += 1 page_dirty[b["page_number"] - 1] += 1
page_dirty = set([i + 1 for i, t in enumerate(page_dirty) if t > 3]) page_dirty = set([i + 1 for i, t in enumerate(page_dirty) if t > 3])
if not page_dirty: return if not page_dirty:
return
i = 0 i = 0
while i < len(self.boxes): while i < len(self.boxes):
if self.boxes[i]["page_number"] in page_dirty: if self.boxes[i]["page_number"] in page_dirty:
@ -582,7 +607,8 @@ class HuParser:
b_["top"] = b["top"] b_["top"] = b["top"]
self.boxes.pop(i) self.boxes.pop(i)
def _extract_table_figure(self, need_image, ZM, return_html, need_position): def _extract_table_figure(self, need_image, ZM,
return_html, need_position):
tables = {} tables = {}
figures = {} figures = {}
# extract figure and table boxes # extract figure and table boxes
@ -761,7 +787,8 @@ class HuParser:
for k, bxs in tables.items(): for k, bxs in tables.items():
if not bxs: if not bxs:
continue continue
bxs = Recognizer.sort_Y_firstly(bxs, np.mean([(b["bottom"] - b["top"]) / 2 for b in bxs])) bxs = Recognizer.sort_Y_firstly(bxs, np.mean(
[(b["bottom"] - b["top"]) / 2 for b in bxs]))
poss = [] poss = []
res.append((cropout(bxs, "table", poss), res.append((cropout(bxs, "table", poss),
self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english))) self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)))
@ -769,7 +796,8 @@ class HuParser:
assert len(positions) == len(res) assert len(positions) == len(res)
if need_position: return list(zip(res, positions)) if need_position:
return list(zip(res, positions))
return res return res
def proj_match(self, line): def proj_match(self, line):
@ -873,7 +901,8 @@ class HuParser:
boxes.pop(0) boxes.pop(0)
mw = np.mean(widths) mw = np.mean(widths)
if mj or mw / pw >= 0.35 or mw > 200: if mj or mw / pw >= 0.35 or mw > 200:
res.append("\n".join([c["text"] + self._line_tag(c, ZM) for c in lines])) res.append(
"\n".join([c["text"] + self._line_tag(c, ZM) for c in lines]))
else: else:
logging.debug("REMOVED: " + logging.debug("REMOVED: " +
"<<".join([c["text"] for c in lines])) "<<".join([c["text"] for c in lines]))
@ -883,13 +912,16 @@ class HuParser:
@staticmethod @staticmethod
def total_page_number(fnm, binary=None): def total_page_number(fnm, binary=None):
try: try:
pdf = pdfplumber.open(fnm) if not binary else pdfplumber.open(BytesIO(binary)) pdf = pdfplumber.open(
fnm) if not binary else pdfplumber.open(BytesIO(binary))
return len(pdf.pages) return len(pdf.pages)
except Exception as e: except Exception as e:
pdf = fitz.open(fnm) if not binary else fitz.open(stream=fnm, filetype="pdf") pdf = fitz.open(fnm) if not binary else fitz.open(
stream=fnm, filetype="pdf")
return len(pdf) return len(pdf)
def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None): def __images__(self, fnm, zoomin=3, page_from=0,
page_to=299, callback=None):
self.lefted_chars = [] self.lefted_chars = []
self.mean_height = [] self.mean_height = []
self.mean_width = [] self.mean_width = []
@ -899,21 +931,26 @@ class HuParser:
self.page_layout = [] self.page_layout = []
self.page_from = page_from self.page_from = page_from
try: try:
self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm)) self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm))
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
enumerate(self.pdf.pages[page_from:page_to])] enumerate(self.pdf.pages[page_from:page_to])]
self.page_chars = [[c for c in page.chars if self._has_color(c)] for page in self.page_chars = [[c for c in page.chars if self._has_color(c)] for page in
self.pdf.pages[page_from:page_to]] self.pdf.pages[page_from:page_to]]
self.total_page = len(self.pdf.pages) self.total_page = len(self.pdf.pages)
except Exception as e: except Exception as e:
self.pdf = fitz.open(fnm) if isinstance(fnm, str) else fitz.open(stream=fnm, filetype="pdf") self.pdf = fitz.open(fnm) if isinstance(
fnm, str) else fitz.open(
stream=fnm, filetype="pdf")
self.page_images = [] self.page_images = []
self.page_chars = [] self.page_chars = []
mat = fitz.Matrix(zoomin, zoomin) mat = fitz.Matrix(zoomin, zoomin)
self.total_page = len(self.pdf) self.total_page = len(self.pdf)
for i, page in enumerate(self.pdf): for i, page in enumerate(self.pdf):
if i < page_from: continue if i < page_from:
if i >= page_to: break continue
if i >= page_to:
break
pix = page.get_pixmap(matrix=mat) pix = page.get_pixmap(matrix=mat)
img = Image.frombytes("RGB", [pix.width, pix.height], img = Image.frombytes("RGB", [pix.width, pix.height],
pix.samples) pix.samples)
@ -930,7 +967,7 @@ class HuParser:
if isinstance(a, dict): if isinstance(a, dict):
self.outlines.append((a["/Title"], depth)) self.outlines.append((a["/Title"], depth))
continue continue
dfs(a, depth+1) dfs(a, depth + 1)
dfs(outlines, 0) dfs(outlines, 0)
except Exception as e: except Exception as e:
logging.warning(f"Outlines exception: {e}") logging.warning(f"Outlines exception: {e}")
@ -941,7 +978,8 @@ class HuParser:
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join( self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
range(len(self.page_chars))] range(len(self.page_chars))]
if sum([1 if e else 0 for e in self.is_english]) > len(self.page_images) / 2: if sum([1 if e else 0 for e in self.is_english]) > len(
self.page_images) / 2:
self.is_english = True self.is_english = True
else: else:
self.is_english = False self.is_english = False
@ -970,9 +1008,11 @@ class HuParser:
# self.page_cum_height.append( # self.page_cum_height.append(
# np.max([c["bottom"] for c in chars])) # np.max([c["bottom"] for c in chars]))
self.__ocr(i + 1, img, chars, zoomin) self.__ocr(i + 1, img, chars, zoomin)
if callback: callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") if callback:
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
if not self.is_english and not any([c for c in self.page_chars]) and self.boxes: if not self.is_english and not any(
[c for c in self.page_chars]) and self.boxes:
bxes = [b for bxs in self.boxes for b in bxs] bxes = [b for bxs in self.boxes for b in bxs]
self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}",
"".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))])) "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))
@ -989,7 +1029,8 @@ class HuParser:
self._text_merge() self._text_merge()
self._concat_downward() self._concat_downward()
self._filter_forpages() self._filter_forpages()
tbls = self._extract_table_figure(need_image, zoomin, return_html, False) tbls = self._extract_table_figure(
need_image, zoomin, return_html, False)
return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls
def remove_tag(self, txt): def remove_tag(self, txt):
@ -1003,15 +1044,19 @@ class HuParser:
"#").strip("@").split("\t") "#").strip("@").split("\t")
left, right, top, bottom = float(left), float( left, right, top, bottom = float(left), float(
right), float(top), float(bottom) right), float(top), float(bottom)
poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom)) poss.append(([int(p) - 1 for p in pn.split("-")],
left, right, top, bottom))
if not poss: if not poss:
if need_position: return None, None if need_position:
return None, None
return return
max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6) max_width = max(
np.max([right - left for (_, left, right, _, _) in poss]), 6)
GAP = 6 GAP = 6
pos = poss[0] pos = poss[0]
poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0))) poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(
0, pos[3] - 120), max(pos[3] - GAP, 0)))
pos = poss[-1] pos = poss[-1]
poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP), poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP),
min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120))) min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120)))
@ -1047,7 +1092,8 @@ class HuParser:
bottom -= self.page_images[pn].size[1] bottom -= self.page_images[pn].size[1]
if not imgs: if not imgs:
if need_position: return None, None if need_position:
return None, None
return return
height = 0 height = 0
for img in imgs: for img in imgs:
@ -1076,12 +1122,14 @@ class HuParser:
pn = bx["page_number"] pn = bx["page_number"]
top = bx["top"] - self.page_cum_height[pn - 1] top = bx["top"] - self.page_cum_height[pn - 1]
bott = bx["bottom"] - self.page_cum_height[pn - 1] bott = bx["bottom"] - self.page_cum_height[pn - 1]
poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM))) poss.append((pn, bx["x0"], bx["x1"], top, min(
bott, self.page_images[pn - 1].size[1] / ZM)))
while bott * ZM > self.page_images[pn - 1].size[1]: while bott * ZM > self.page_images[pn - 1].size[1]:
bott -= self.page_images[pn - 1].size[1] / ZM bott -= self.page_images[pn - 1].size[1] / ZM
top = 0 top = 0
pn += 1 pn += 1
poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM))) poss.append((pn, bx["x0"], bx["x1"], top, min(
bott, self.page_images[pn - 1].size[1] / ZM)))
return poss return poss
@ -1090,11 +1138,14 @@ class PlainParser(object):
self.outlines = [] self.outlines = []
lines = [] lines = []
try: try:
self.pdf = pdf2_read(filename if isinstance(filename, str) else BytesIO(filename)) self.pdf = pdf2_read(
filename if isinstance(
filename, str) else BytesIO(filename))
for page in self.pdf.pages[from_page:to_page]: for page in self.pdf.pages[from_page:to_page]:
lines.extend([t for t in page.extract_text().split("\n")]) lines.extend([t for t in page.extract_text().split("\n")])
outlines = self.pdf.outline outlines = self.pdf.outline
def dfs(arr, depth): def dfs(arr, depth):
for a in arr: for a in arr:
if isinstance(a, dict): if isinstance(a, dict):
@ -1117,5 +1168,6 @@ class PlainParser(object):
def remove_tag(txt): def remove_tag(txt):
raise NotImplementedError raise NotImplementedError
if __name__ == "__main__": if __name__ == "__main__":
pass pass
@ -23,7 +23,8 @@ class HuPptParser(object):
tb = shape.table tb = shape.table
rows = [] rows = []
for i in range(1, len(tb.rows)): for i in range(1, len(tb.rows)):
rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) rows.append("; ".join([tb.cell(
0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
return "\n".join(rows) return "\n".join(rows)
if shape.has_text_frame: if shape.has_text_frame:
@ -31,9 +32,10 @@ class HuPptParser(object):
if shape.shape_type == 6: if shape.shape_type == 6:
texts = [] texts = []
for p in sorted(shape.shapes, key=lambda x: (x.top//10, x.left)): for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)):
t = self.__extract(p) t = self.__extract(p)
if t: texts.append(t) if t:
texts.append(t)
return "\n".join(texts) return "\n".join(texts)
def __call__(self, fnm, from_page, to_page, callback=None): def __call__(self, fnm, from_page, to_page, callback=None):
@ -43,12 +45,16 @@ class HuPptParser(object):
txts = [] txts = []
self.total_page = len(ppt.slides) self.total_page = len(ppt.slides)
for i, slide in enumerate(ppt.slides): for i, slide in enumerate(ppt.slides):
if i < from_page: continue if i < from_page:
if i >= to_page:break continue
if i >= to_page:
break
texts = [] texts = []
for shape in sorted(slide.shapes, key=lambda x: (x.top//10, x.left)): for shape in sorted(
slide.shapes, key=lambda x: (x.top // 10, x.left)):
txt = self.__extract(shape) txt = self.__extract(shape)
if txt: texts.append(txt) if txt:
texts.append(txt)
txts.append("\n".join(texts)) txts.append("\n".join(texts))
return txts return txts
@ -36,6 +36,7 @@ class LayoutRecognizer(Recognizer):
"Reference", "Reference",
"Equation", "Equation",
] ]
def __init__(self, domain): def __init__(self, domain):
try: try:
model_dir = snapshot_download( model_dir = snapshot_download(
@ -47,10 +48,12 @@ class LayoutRecognizer(Recognizer):
except Exception as e: except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
super().__init__(self.labels, domain, model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) # os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
super().__init__(self.labels, domain, model_dir)
self.garbage_layouts = ["footer", "header", "reference"] self.garbage_layouts = ["footer", "header", "reference"]
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True): def __call__(self, image_list, ocr_res, scale_factor=3,
thr=0.2, batch_size=16, drop=True):
def __is_garbage(b): def __is_garbage(b):
patt = [r"^•+$", r"(版权归©|免责条款|地址[:])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", patt = [r"^•+$", r"(版权归©|免责条款|地址[:])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
@ -75,7 +78,8 @@ class LayoutRecognizer(Recognizer):
"top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor, "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
"page_number": pn, "page_number": pn,
} for b in lts] } for b in lts]
lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2) lts = self.sort_Y_firstly(lts, np.mean(
[l["bottom"] - l["top"] for l in lts]) / 2)
lts = self.layouts_cleanup(bxs, lts) lts = self.layouts_cleanup(bxs, lts)
page_layout.append(lts) page_layout.append(lts)
@ -100,10 +104,13 @@ class LayoutRecognizer(Recognizer):
continue continue
lts_[ii]["visited"] = True lts_[ii]["visited"] = True
keep_feats = [ keep_feats = [
lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1]*0.9/scale_factor, lts_[
lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1]*0.1/scale_factor, ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
lts_[
ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
] ]
if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats): if drop and lts_[
ii]["type"] in self.garbage_layouts and not any(keep_feats):
if lts_[ii]["type"] not in garbages: if lts_[ii]["type"] not in garbages:
garbages[lts_[ii]["type"]] = [] garbages[lts_[ii]["type"]] = []
garbages[lts_[ii]["type"]].append(bxs[i]["text"]) garbages[lts_[ii]["type"]].append(bxs[i]["text"])
@ -111,7 +118,8 @@ class LayoutRecognizer(Recognizer):
continue continue
bxs[i]["layoutno"] = f"{ty}-{ii}" bxs[i]["layoutno"] = f"{ty}-{ii}"
bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"]!="equation" else "figure" bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[
ii]["type"] != "equation" else "figure"
i += 1 i += 1
for lt in ["footer", "header", "reference", "figure caption", for lt in ["footer", "header", "reference", "figure caption",
@ -120,7 +128,7 @@ class LayoutRecognizer(Recognizer):
# add box to figure layouts which has not text box # add box to figure layouts which has not text box
for i, lt in enumerate( for i, lt in enumerate(
[lt for lt in lts if lt["type"] in ["figure","equation"]]): [lt for lt in lts if lt["type"] in ["figure", "equation"]]):
if lt.get("visited"): if lt.get("visited"):
continue continue
lt = deepcopy(lt) lt = deepcopy(lt)
@ -143,6 +151,3 @@ class LayoutRecognizer(Recognizer):
ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set] ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
return ocr_res, page_layout return ocr_res, page_layout
@ -63,6 +63,7 @@ class DecodeImage(object):
data['image'] = img data['image'] = img
return data return data
class StandardizeImage(object): class StandardizeImage(object):
"""normalize image """normalize image
Args: Args:
@ -11,12 +11,20 @@
# limitations under the License. # limitations under the License.
# #
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))
import numpy as np
import argparse
from deepdoc.vision import OCR, init_in_out
from deepdoc.vision.seeit import draw_box from deepdoc.vision.seeit import draw_box
from deepdoc.vision import OCR, init_in_out
import argparse
import numpy as np
import os
import sys
sys.path.insert(
0,
os.path.abspath(
os.path.join(
os.path.dirname(
os.path.abspath(__file__)),
'../../')))
def main(args): def main(args):
ocr = OCR() ocr = OCR()
@ -32,8 +40,8 @@ def main(args):
"score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]] "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
img = draw_box(images[i], bxs, ["ocr"], 1.) img = draw_box(images[i], bxs, ["ocr"], 1.)
img.save(outputs[i], quality=95) img.save(outputs[i], quality=95)
with open(outputs[i] + ".txt", "w+") as f: f.write("\n".join([o["text"] for o in bxs])) with open(outputs[i] + ".txt", "w+") as f:
f.write("\n".join([o["text"] for o in bxs]))
if __name__ == "__main__": if __name__ == "__main__":
@ -11,24 +11,35 @@
# limitations under the License. # limitations under the License.
# #
import os, sys from deepdoc.vision.seeit import draw_box
from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
from api.utils.file_utils import get_project_base_directory
import argparse
import os
import sys
import re import re
import numpy as np import numpy as np
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))) sys.path.insert(
0,
import argparse os.path.abspath(
from api.utils.file_utils import get_project_base_directory os.path.join(
from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out os.path.dirname(
from deepdoc.vision.seeit import draw_box os.path.abspath(__file__)),
'../../')))
def main(args): def main(args):
images, outputs = init_in_out(args) images, outputs = init_in_out(args)
if args.mode.lower() == "layout": if args.mode.lower() == "layout":
labels = LayoutRecognizer.labels labels = LayoutRecognizer.labels
detr = Recognizer(labels, "layout", os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) detr = Recognizer(
labels,
"layout",
os.path.join(
get_project_base_directory(),
"rag/res/deepdoc/"))
if args.mode.lower() == "tsr": if args.mode.lower() == "tsr":
labels = TableStructureRecognizer.labels labels = TableStructureRecognizer.labels
detr = TableStructureRecognizer() detr = TableStructureRecognizer()
@ -39,7 +50,8 @@ def main(args):
if args.mode.lower() == "tsr": if args.mode.lower() == "tsr":
#lyt = [t for t in lyt if t["type"] == "table column"] #lyt = [t for t in lyt if t["type"] == "table column"]
html = get_table_html(images[i], lyt, ocr) html = get_table_html(images[i], lyt, ocr)
with open(outputs[i]+".html", "w+") as f: f.write(html) with open(outputs[i] + ".html", "w+") as f:
f.write(html)
lyt = [{ lyt = [{
"type": t["label"], "type": t["label"],
"bbox": [t["x0"], t["top"], t["x1"], t["bottom"]], "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
@ -58,7 +70,7 @@ def get_table_html(img, tb_cpns, ocr):
"bottom": b[-1][1], "bottom": b[-1][1],
"layout_type": "table", "layout_type": "table",
"page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]], "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
np.mean([b[-1][1]-b[0][1] for b,_ in boxes]) / 3 np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3
) )
def gather(kwd, fzy=10, ption=0.6): def gather(kwd, fzy=10, ption=0.6):
@ -157,7 +169,7 @@ def get_table_html(img, tb_cpns, ocr):
%s %s
</body> </body>
</html> </html>
"""% TableStructureRecognizer.construct_table(boxes, html=True) """ % TableStructureRecognizer.construct_table(boxes, html=True)
return html return html
@ -168,7 +180,10 @@ if __name__ == "__main__":
required=True) required=True)
parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'", parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
default="./layouts_outputs") default="./layouts_outputs")
parser.add_argument('--threshold', help="A threshold to filter out detections. Default: 0.5", default=0.5) parser.add_argument(
'--threshold',
help="A threshold to filter out detections. Default: 0.5",
default=0.5)
parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"], parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
default="layout") default="layout")
args = parser.parse_args() args = parser.parse_args()
@ -44,7 +44,8 @@ class TableStructureRecognizer(Recognizer):
except Exception as e: except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
super().__init__(self.labels, "tsr", model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) # os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
super().__init__(self.labels, "tsr", model_dir)
def __call__(self, images, thr=0.2): def __call__(self, images, thr=0.2):
tbls = super().__call__(images, thr) tbls = super().__call__(images, thr)
@ -138,7 +139,8 @@ class TableStructureRecognizer(Recognizer):
i = 0 i = 0
while i < len(boxes): while i < len(boxes):
if TableStructureRecognizer.is_caption(boxes[i]): if TableStructureRecognizer.is_caption(boxes[i]):
if is_english: cap + " " if is_english:
cap + " "
cap += boxes[i]["text"] cap += boxes[i]["text"]
boxes.pop(i) boxes.pop(i)
i -= 1 i -= 1
@ -366,7 +368,8 @@ class TableStructureRecognizer(Recognizer):
continue continue
txt = "" txt = ""
if arr: if arr:
h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10) h = min(np.min([c["bottom"] - c["top"]
for c in arr]) / 2, 10)
txt = " ".join([c["text"] txt = " ".join([c["text"]
for c in Recognizer.sort_Y_firstly(arr, h)]) for c in Recognizer.sort_Y_firstly(arr, h)])
txts.append(txt) txts.append(txt)
@ -48,10 +48,12 @@ class Pdf(PdfParser):
callback(0.8, "Text extraction finished") callback(0.8, "Text extraction finished")
return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno","")) for b in self.boxes], tbls return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", ""))
for b in self.boxes], tbls
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
""" """
Supported file formats are docx, pdf, txt. Supported file formats are docx, pdf, txt.
Since a book is long and not all the parts are useful, if it's a PDF, Since a book is long and not all the parts are useful, if it's a PDF,
@ -63,48 +65,63 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
} }
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
pdf_parser = None pdf_parser = None
sections,tbls = [], [] sections, tbls = [], []
if re.search(r"\.docx?$", filename, re.IGNORECASE): if re.search(r"\.docx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
doc_parser = DocxParser() doc_parser = DocxParser()
# TODO: table of contents need to be removed # TODO: table of contents need to be removed
sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page) sections, tbls = doc_parser(
remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200))) binary if binary else filename, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() pdf_parser = Pdf() if kwargs.get(
"parser_config", {}).get(
"layout_recognize", True) else PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary, sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback) from_page=from_page, to_page=to_page, callback=callback)
elif re.search(r"\.txt$", filename, re.IGNORECASE): elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = "" txt = ""
if binary:txt = binary.decode("utf-8") if binary:
txt = binary.decode("utf-8")
else: else:
with open(filename, "r") as f: with open(filename, "r") as f:
while True: while True:
l = f.readline() l = f.readline()
if not l:break if not l:
break
txt += l txt += l
sections = txt.split("\n") sections = txt.split("\n")
sections = [(l,"") for l in sections if l] sections = [(l, "") for l in sections if l]
remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200))) remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") else:
raise NotImplementedError(
"file type not supported yet(docx, pdf, txt supported)")
make_colon_as_title(sections) make_colon_as_title(sections)
bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)]) bull = bullets_category(
[t for t in random_choices([t for t, _ in sections], k=100)])
if bull >= 0: if bull >= 0:
chunks = ["\n".join(ck) for ck in hierarchical_merge(bull, sections, 3)] chunks = ["\n".join(ck)
for ck in hierarchical_merge(bull, sections, 3)]
else: else:
sections = [s.split("@") for s,_ in sections] sections = [s.split("@") for s, _ in sections]
sections = [(pr[0], "@"+pr[1]) for pr in sections if len(pr)==2] sections = [(pr[0], "@" + pr[1]) for pr in sections if len(pr) == 2]
chunks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?")) chunks = naive_merge(
sections, kwargs.get(
"chunk_token_num", 256), kwargs.get(
"delimer", "\n。;!?"))
# is it English # is it English
eng = lang.lower() == "english"#is_english(random_choices([t for t, _ in sections], k=218)) # is_english(random_choices([t for t, _ in sections], k=218))
eng = lang.lower() == "english"
res = tokenize_table(tbls, doc, eng) res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
@ -114,6 +131,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)
@ -35,8 +35,10 @@ class Docx(DocxParser):
pn = 0 pn = 0
lines = [] lines = []
for p in self.doc.paragraphs: for p in self.doc.paragraphs:
if pn > to_page:break if pn > to_page:
if from_page <= pn < to_page and p.text.strip(): lines.append(self.__clean(p.text)) break
if from_page <= pn < to_page and p.text.strip():
lines.append(self.__clean(p.text))
for run in p.runs: for run in p.runs:
if 'lastRenderedPageBreak' in run._element.xml: if 'lastRenderedPageBreak' in run._element.xml:
pn += 1 pn += 1
@ -63,15 +65,18 @@ class Pdf(PdfParser):
start = timer() start = timer()
self._layouts_rec(zoomin) self._layouts_rec(zoomin)
callback(0.67, "Layout analysis finished") callback(0.67, "Layout analysis finished")
cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1))) cron_logger.info("paddle layouts:".format(
(timer() - start) / (self.total_page + 0.1)))
self._naive_vertical_merge() self._naive_vertical_merge()
callback(0.8, "Text extraction finished") callback(0.8, "Text extraction finished")
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], None return [(b["text"], self._line_tag(b, zoomin))
for b in self.boxes], None
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
""" """
Supported file formats are docx, pdf, txt. Supported file formats are docx, pdf, txt.
""" """
@ -89,7 +94,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() pdf_parser = Pdf() if kwargs.get(
"parser_config", {}).get(
"layout_recognize", True) else PlainParser()
for txt, poss in pdf_parser(filename if not binary else binary, for txt, poss in pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)[0]: from_page=from_page, to_page=to_page, callback=callback)[0]:
sections.append(txt + poss) sections.append(txt + poss)
@ -97,33 +104,40 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
elif re.search(r"\.txt$", filename, re.IGNORECASE): elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = "" txt = ""
if binary:txt = binary.decode("utf-8") if binary:
txt = binary.decode("utf-8")
else: else:
with open(filename, "r") as f: with open(filename, "r") as f:
while True: while True:
l = f.readline() l = f.readline()
if not l:break if not l:
break
txt += l txt += l
sections = txt.split("\n") sections = txt.split("\n")
sections = [l for l in sections if l] sections = [l for l in sections if l]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") else:
raise NotImplementedError(
"file type not supported yet(docx, pdf, txt supported)")
# is it English # is it English
eng = lang.lower() == "english"#is_english(sections) eng = lang.lower() == "english" # is_english(sections)
# Remove 'Contents' part # Remove 'Contents' part
remove_contents_table(sections, eng) remove_contents_table(sections, eng)
make_colon_as_title(sections) make_colon_as_title(sections)
bull = bullets_category(sections) bull = bullets_category(sections)
chunks = hierarchical_merge(bull, sections, 3) chunks = hierarchical_merge(bull, sections, 3)
if not chunks: callback(0.99, "No chunk parsed out.") if not chunks:
callback(0.99, "No chunk parsed out.")
return tokenize_chunks(["\n".join(ck) for ck in chunks], doc, eng, pdf_parser) return tokenize_chunks(["\n".join(ck)
for ck in chunks], doc, eng, pdf_parser)
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)
@ -25,10 +25,10 @@ class Pdf(PdfParser):
callback callback
) )
callback(msg="OCR finished.") callback(msg="OCR finished.")
#for bb in self.boxes: # for bb in self.boxes:
# for b in bb: # for b in bb:
# print(b) # print(b)
print("OCR:", timer()-start) print("OCR:", timer() - start)
self._layouts_rec(zoomin) self._layouts_rec(zoomin)
callback(0.65, "Layout analysis finished.") callback(0.65, "Layout analysis finished.")
@ -45,30 +45,35 @@ class Pdf(PdfParser):
for b in self.boxes: for b in self.boxes:
b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip()) b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)], tbls return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin))
for i, b in enumerate(self.boxes)], tbls
def chunk(filename, binary=None, from_page=0, to_page=100000,
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): lang="Chinese", callback=None, **kwargs):
""" """
Only pdf is supported. Only pdf is supported.
""" """
pdf_parser = None pdf_parser = None
if re.search(r"\.pdf$", filename, re.IGNORECASE): if re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() pdf_parser = Pdf() if kwargs.get(
"parser_config", {}).get(
"layout_recognize", True) else PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary, sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback) from_page=from_page, to_page=to_page, callback=callback)
if sections and len(sections[0])<3: sections = [(t, l, [[0]*5]) for t, l in sections] if sections and len(sections[0]) < 3:
sections = [(t, l, [[0] * 5]) for t, l in sections]
else: raise NotImplementedError("file type not supported yet(pdf supported)") else:
raise NotImplementedError("file type not supported yet(pdf supported)")
doc = { doc = {
"docnm_kwd": filename "docnm_kwd": filename
} }
doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"])) doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
# is it English # is it English
eng = lang.lower() == "english"#pdf_parser.is_english eng = lang.lower() == "english" # pdf_parser.is_english
# set pivot using the most frequent type of title, # set pivot using the most frequent type of title,
# then merge between 2 pivot # then merge between 2 pivot
@ -79,7 +84,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
for txt, _, _ in sections: for txt, _, _ in sections:
for t, lvl in pdf_parser.outlines: for t, lvl in pdf_parser.outlines:
tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)]) tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)])
tks_ = set([txt[i] + txt[i + 1] for i in range(min(len(t), len(txt) - 1))]) tks_ = set([txt[i] + txt[i + 1]
for i in range(min(len(t), len(txt) - 1))])
if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8: if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8:
levels.append(lvl) levels.append(lvl)
break break
@ -87,24 +93,27 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
levels.append(max_lvl + 1) levels.append(max_lvl + 1)
else: else:
bull = bullets_category([txt for txt,_,_ in sections]) bull = bullets_category([txt for txt, _, _ in sections])
most_level, levels = title_frequency(bull, [(txt, l) for txt, l, poss in sections]) most_level, levels = title_frequency(
bull, [(txt, l) for txt, l, poss in sections])
assert len(sections) == len(levels) assert len(sections) == len(levels)
sec_ids = [] sec_ids = []
sid = 0 sid = 0
for i, lvl in enumerate(levels): for i, lvl in enumerate(levels):
if lvl <= most_level and i > 0 and lvl != levels[i - 1]: sid += 1 if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
sid += 1
sec_ids.append(sid) sec_ids.append(sid)
# print(lvl, self.boxes[i]["text"], most_level, sid) # print(lvl, self.boxes[i]["text"], most_level, sid)
sections = [(txt, sec_ids[i], poss) for i, (txt, _, poss) in enumerate(sections)] sections = [(txt, sec_ids[i], poss)
for i, (txt, _, poss) in enumerate(sections)]
for (img, rows), poss in tbls: for (img, rows), poss in tbls:
sections.append((rows if isinstance(rows, str) else rows[0], -1, sections.append((rows if isinstance(rows, str) else rows[0], -1,
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
def tag(pn, left, right, top, bottom): def tag(pn, left, right, top, bottom):
if pn+left+right+top+bottom == 0: if pn + left + right + top + bottom == 0:
return "" return ""
return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \ return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
.format(pn, left, right, top, bottom) .format(pn, left, right, top, bottom)
@ -112,7 +121,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
chunks = [] chunks = []
last_sid = -2 last_sid = -2
tk_cnt = 0 tk_cnt = 0
for txt, sec_id, poss in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1])): for txt, sec_id, poss in sorted(sections, key=lambda x: (
x[-1][0][0], x[-1][0][3], x[-1][0][1])):
poss = "\t".join([tag(*pos) for pos in poss]) poss = "\t".join([tag(*pos) for pos in poss])
if tk_cnt < 2048 and (sec_id == last_sid or sec_id == -1): if tk_cnt < 2048 and (sec_id == last_sid or sec_id == -1):
if chunks: if chunks:
@ -121,16 +131,17 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
continue continue
chunks.append(txt + poss) chunks.append(txt + poss)
tk_cnt = num_tokens_from_string(txt) tk_cnt = num_tokens_from_string(txt)
if sec_id > -1: last_sid = sec_id if sec_id > -1:
last_sid = sec_id
res = tokenize_table(tbls, doc, eng) res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
return res return res
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)
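The tag() helper above encodes a box position as an inline marker that gets appended to every merged chunk. A minimal standalone sketch of the format it produces (the numbers are illustrative, not taken from the diff):

def tag(pn, left, right, top, bottom):   # trimmed copy of the helper above
    if pn + left + right + top + bottom == 0:
        return ""
    return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format(pn, left, right, top, bottom)

print(tag(3, 12.0, 480.5, 88.2, 140.0))  # @@3 12.0 480.5 88.2 140.0##  (tab-separated)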
View File
@ -44,11 +44,14 @@ class Pdf(PdfParser):
tbls = self._extract_table_figure(True, zoomin, True, True) tbls = self._extract_table_figure(True, zoomin, True, True)
self._naive_vertical_merge() self._naive_vertical_merge()
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1))) cron_logger.info("paddle layouts:".format(
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls (timer() - start) / (self.total_page + 0.1)))
return [(b["text"], self._line_tag(b, zoomin))
for b in self.boxes], tbls
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
""" """
Supported file formats are docx, pdf, excel, txt. Supported file formats are docx, pdf, excel, txt.
    This method applies the naive way to chunk files.     This method applies the naive way to chunk files.
@ -56,8 +59,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
    Next, these successive pieces are merged into chunks whose token number is no more than 'Max token number'.     Next, these successive pieces are merged into chunks whose token number is no more than 'Max token number'.
""" """
eng = lang.lower() == "english"#is_english(cks) eng = lang.lower() == "english" # is_english(cks)
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True}) parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True})
doc = { doc = {
"docnm_kwd": filename, "docnm_kwd": filename,
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
@ -73,7 +78,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if parser_config["layout_recognize"] else PlainParser() pdf_parser = Pdf(
) if parser_config["layout_recognize"] else PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary, sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback) from_page=from_page, to_page=to_page, callback=callback)
res = tokenize_table(tbls, doc, eng) res = tokenize_table(tbls, doc, eng)
@ -92,16 +98,21 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
with open(filename, "r") as f: with open(filename, "r") as f:
while True: while True:
l = f.readline() l = f.readline()
if not l: break if not l:
break
txt += l txt += l
sections = txt.split("\n") sections = txt.split("\n")
sections = [(l, "") for l in sections if l] sections = [(l, "") for l in sections if l]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
else: else:
raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") raise NotImplementedError(
"file type not supported yet(docx, pdf, txt supported)")
chunks = naive_merge(sections, parser_config.get("chunk_token_num", 128), parser_config.get("delimiter", "\n!?。;!?")) chunks = naive_merge(
sections, parser_config.get(
"chunk_token_num", 128), parser_config.get(
"delimiter", "\n!?。;!?"))
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
return res return res
@ -110,9 +121,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
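A hedged sketch of driving the naive chunker above with an explicit parser_config; the file name is a placeholder, chunk() is assumed to be in scope from this module, and the keys simply mirror the defaults shown in the function:

def progress(prog=None, msg=""):
    print(prog, msg)

# assuming this module's chunk() is importable; "example.pdf" is a stand-in path
chunks = chunk(
    "example.pdf",
    parser_config={"chunk_token_num": 128,
                   "delimiter": "\n!?。;!?",
                   "layout_recognize": True},
    callback=progress)
print(len(chunks), "chunks")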
View File
@ -41,20 +41,23 @@ class Pdf(PdfParser):
tbls = self._extract_table_figure(True, zoomin, True, True) tbls = self._extract_table_figure(True, zoomin, True, True)
self._concat_downward() self._concat_downward()
sections = [(b["text"], self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)] sections = [(b["text"], self.get_position(b, zoomin))
for i, b in enumerate(self.boxes)]
for (img, rows), poss in tbls: for (img, rows), poss in tbls:
sections.append((rows if isinstance(rows, str) else rows[0], sections.append((rows if isinstance(rows, str) else rows[0],
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
""" """
Supported file formats are docx, pdf, excel, txt. Supported file formats are docx, pdf, excel, txt.
One file forms a chunk which maintains original text order. One file forms a chunk which maintains original text order.
""" """
eng = lang.lower() == "english"#is_english(cks) eng = lang.lower() == "english" # is_english(cks)
if re.search(r"\.docx?$", filename, re.IGNORECASE): if re.search(r"\.docx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
@ -62,8 +65,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() pdf_parser = Pdf() if kwargs.get(
sections, _ = pdf_parser(filename if not binary else binary, to_page=to_page, callback=callback) "parser_config", {}).get(
"layout_recognize", True) else PlainParser()
sections, _ = pdf_parser(
filename if not binary else binary, to_page=to_page, callback=callback)
sections = [s for s, _ in sections if s] sections = [s for s, _ in sections if s]
elif re.search(r"\.xlsx?$", filename, re.IGNORECASE): elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
@ -80,14 +86,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
with open(filename, "r") as f: with open(filename, "r") as f:
while True: while True:
l = f.readline() l = f.readline()
if not l: break if not l:
break
txt += l txt += l
sections = txt.split("\n") sections = txt.split("\n")
sections = [s for s in sections if s] sections = [s for s in sections if s]
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
else: else:
raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") raise NotImplementedError(
"file type not supported yet(docx, pdf, txt supported)")
doc = { doc = {
"docnm_kwd": filename, "docnm_kwd": filename,
@ -101,9 +109,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
View File
@ -67,7 +67,7 @@ class Pdf(PdfParser):
if from_page > 0: if from_page > 0:
return { return {
"title":"", "title": "",
"authors": "", "authors": "",
"abstract": "", "abstract": "",
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if
@ -87,7 +87,8 @@ class Pdf(PdfParser):
title = "" title = ""
break break
for j in range(3): for j in range(3):
if _begin(self.boxes[i + j]["text"]): break if _begin(self.boxes[i + j]["text"]):
break
authors.append(self.boxes[i + j]["text"]) authors.append(self.boxes[i + j]["text"])
break break
break break
@ -107,10 +108,15 @@ class Pdf(PdfParser):
abstr = txt + self._line_tag(self.boxes[i], zoomin) abstr = txt + self._line_tag(self.boxes[i], zoomin)
i += 1 i += 1
break break
if not abstr: i = 0 if not abstr:
i = 0
callback(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page))) callback(
for b in self.boxes: print(b["text"], b.get("layoutno")) 0.8, "Page {}~{}: Text merging finished".format(
from_page, min(
to_page, self.total_page)))
for b in self.boxes:
print(b["text"], b.get("layoutno"))
print(tbls) print(tbls)
return { return {
@ -123,14 +129,15 @@ class Pdf(PdfParser):
} }
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
""" """
Only pdf is supported. Only pdf is supported.
    The abstract of the paper will be sliced as an entire chunk, and will not be split into parts.     The abstract of the paper will be sliced as an entire chunk, and will not be split into parts.
""" """
pdf_parser = None pdf_parser = None
if re.search(r"\.pdf$", filename, re.IGNORECASE): if re.search(r"\.pdf$", filename, re.IGNORECASE):
if not kwargs.get("parser_config",{}).get("layout_recognize", True): if not kwargs.get("parser_config", {}).get("layout_recognize", True):
pdf_parser = PlainParser() pdf_parser = PlainParser()
paper = { paper = {
"title": filename, "title": filename,
@ -143,14 +150,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
pdf_parser = Pdf() pdf_parser = Pdf()
paper = pdf_parser(filename if not binary else binary, paper = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback) from_page=from_page, to_page=to_page, callback=callback)
else: raise NotImplementedError("file type not supported yet(pdf supported)") else:
raise NotImplementedError("file type not supported yet(pdf supported)")
doc = {"docnm_kwd": filename, "authors_tks": huqie.qie(paper["authors"]), doc = {"docnm_kwd": filename, "authors_tks": huqie.qie(paper["authors"]),
"title_tks": huqie.qie(paper["title"] if paper["title"] else filename)} "title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"]) doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
# is it English # is it English
eng = lang.lower() == "english"#pdf_parser.is_english eng = lang.lower() == "english" # pdf_parser.is_english
print("It's English.....", eng) print("It's English.....", eng)
res = tokenize_table(paper["tables"], doc, eng) res = tokenize_table(paper["tables"], doc, eng)
@ -160,7 +168,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
txt = pdf_parser.remove_tag(paper["abstract"]) txt = pdf_parser.remove_tag(paper["abstract"])
d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"] d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"]
d["important_tks"] = " ".join(d["important_kwd"]) d["important_tks"] = " ".join(d["important_kwd"])
d["image"], poss = pdf_parser.crop(paper["abstract"], need_position=True) d["image"], poss = pdf_parser.crop(
paper["abstract"], need_position=True)
add_positions(d, poss) add_positions(d, poss)
tokenize(d, txt, eng) tokenize(d, txt, eng)
res.append(d) res.append(d)
@ -174,7 +183,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
sec_ids = [] sec_ids = []
sid = 0 sid = 0
for i, lvl in enumerate(levels): for i, lvl in enumerate(levels):
if lvl <= most_level and i > 0 and lvl != levels[i-1]: sid += 1 if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
sid += 1
sec_ids.append(sid) sec_ids.append(sid)
print(lvl, sorted_sections[i][0], most_level, sid) print(lvl, sorted_sections[i][0], most_level, sid)
@ -190,6 +200,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
return res return res
""" """
readed = [0] * len(paper["lines"]) readed = [0] * len(paper["lines"])
# find colon firstly # find colon firstly
@ -212,7 +223,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
for k in range(j, i): readed[k] = True for k in range(j, i): readed[k] = True
txt = txt[::-1] txt = txt[::-1]
if eng: if eng:
r = re.search(r"(.*?) ([\.;?!]|$)", txt) r = re.search(r"(.*?) ([\\.;?!]|$)", txt)
txt = r.group(1)[::-1] if r else txt[::-1] txt = r.group(1)[::-1] if r else txt[::-1]
else: else:
r = re.search(r"(.*?) ([。?;!]|$)", txt) r = re.search(r"(.*?) ([。?;!]|$)", txt)
@ -270,6 +281,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
def dummy(prog=None, msg=""): def dummy(prog=None, msg=""):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)
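The sec_ids loop above (the same pattern appears in the first file of this diff) only advances the section counter when a heading at or above the pivot level appears and its level differs from the previous one. A toy illustration with made-up outline levels:

levels = [1, 2, 2, 1, 3, 1]   # hypothetical title levels, one per section
most_level = 1                # pivot: the most frequent title level
sec_ids, sid = [], 0
for i, lvl in enumerate(levels):
    if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
        sid += 1
    sec_ids.append(sid)
print(sec_ids)   # [0, 0, 0, 1, 1, 2]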
View File
@ -33,9 +33,12 @@ class Ppt(PptParser):
with slides.Presentation(BytesIO(fnm)) as presentation: with slides.Presentation(BytesIO(fnm)) as presentation:
for i, slide in enumerate(presentation.slides[from_page: to_page]): for i, slide in enumerate(presentation.slides[from_page: to_page]):
buffered = BytesIO() buffered = BytesIO()
slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg) slide.get_thumbnail(
0.5, 0.5).save(
buffered, drawing.imaging.ImageFormat.jpeg)
imgs.append(Image.open(buffered)) imgs.append(Image.open(buffered))
assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) assert len(imgs) == len(
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
callback(0.9, "Image extraction finished") callback(0.9, "Image extraction finished")
self.is_english = is_english(txts) self.is_english = is_english(txts)
return [(txts[i], imgs[i]) for i in range(len(txts))] return [(txts[i], imgs[i]) for i in range(len(txts))]
@ -47,25 +50,34 @@ class Pdf(PdfParser):
def __garbage(self, txt): def __garbage(self, txt):
txt = txt.lower().strip() txt = txt.lower().strip()
if re.match(r"[0-9\.,%/-]+$", txt): return True if re.match(r"[0-9\.,%/-]+$", txt):
if len(txt) < 3:return True return True
if len(txt) < 3:
return True
return False return False
def __call__(self, filename, binary=None, from_page=0, to_page=100000, zoomin=3, callback=None): def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
callback(msg="OCR is running...") callback(msg="OCR is running...")
self.__images__(filename if not binary else binary, zoomin, from_page, to_page, callback) self.__images__(filename if not binary else binary,
callback(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page))) zoomin, from_page, to_page, callback)
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images)) callback(0.8, "Page {}~{}: OCR finished".format(
from_page, min(to_page, self.total_page)))
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
len(self.boxes), len(self.page_images))
res = [] res = []
for i in range(len(self.boxes)): for i in range(len(self.boxes)):
lines = "\n".join([b["text"] for b in self.boxes[i] if not self.__garbage(b["text"])]) lines = "\n".join([b["text"] for b in self.boxes[i]
if not self.__garbage(b["text"])])
res.append((lines, self.page_images[i])) res.append((lines, self.page_images[i]))
callback(0.9, "Page {}~{}: Parsing finished".format(from_page, min(to_page, self.total_page))) callback(0.9, "Page {}~{}: Parsing finished".format(
from_page, min(to_page, self.total_page)))
return res return res
class PlainPdf(PlainParser): class PlainPdf(PlainParser):
def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs): def __call__(self, filename, binary=None, from_page=0,
to_page=100000, callback=None, **kwargs):
self.pdf = pdf2_read(filename if not binary else BytesIO(binary)) self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
page_txt = [] page_txt = []
for page in self.pdf.pages[from_page: to_page]: for page in self.pdf.pages[from_page: to_page]:
@ -74,7 +86,8 @@ class PlainPdf(PlainParser):
return [(txt, None) for txt in page_txt] return [(txt, None) for txt in page_txt]
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
""" """
The supported file formats are pdf, pptx. The supported file formats are pdf, pptx.
    Every page will be treated as a chunk, and the thumbnail of every page will be stored.     Every page will be treated as a chunk, and the thumbnail of every page will be stored.
@ -89,35 +102,42 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
res = [] res = []
if re.search(r"\.pptx?$", filename, re.IGNORECASE): if re.search(r"\.pptx?$", filename, re.IGNORECASE):
ppt_parser = Ppt() ppt_parser = Ppt()
for pn, (txt,img) in enumerate(ppt_parser(filename if not binary else binary, from_page, 1000000, callback)): for pn, (txt, img) in enumerate(ppt_parser(
filename if not binary else binary, from_page, 1000000, callback)):
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
pn += from_page pn += from_page
d["image"] = img d["image"] = img
d["page_num_int"] = [pn+1] d["page_num_int"] = [pn + 1]
d["top_int"] = [0] d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])] d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
tokenize(d, txt, eng) tokenize(d, txt, eng)
res.append(d) res.append(d)
return res return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainPdf() pdf_parser = Pdf() if kwargs.get(
for pn, (txt,img) in enumerate(pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)): "parser_config", {}).get(
"layout_recognize", True) else PlainPdf()
for pn, (txt, img) in enumerate(pdf_parser(filename, binary,
from_page=from_page, to_page=to_page, callback=callback)):
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
pn += from_page pn += from_page
if img: d["image"] = img if img:
d["page_num_int"] = [pn+1] d["image"] = img
d["page_num_int"] = [pn + 1]
d["top_int"] = [0] d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)] d["position_int"] = [
(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
tokenize(d, txt, eng) tokenize(d, txt, eng)
res.append(d) res.append(d)
return res return res
raise NotImplementedError("file type not supported yet(pptx, pdf supported)") raise NotImplementedError(
"file type not supported yet(pptx, pdf supported)")
if __name__== "__main__": if __name__ == "__main__":
import sys import sys
def dummy(a, b): def dummy(a, b):
pass pass
chunk(sys.argv[1], callback=dummy) chunk(sys.argv[1], callback=dummy)
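Rough shape of a single chunk emitted by the presentation parser above, one per slide or page; the thumbnail, sizes and file name below are stand-ins, and only fields visible in this diff are shown:

from PIL import Image

img, pn = Image.new("RGB", (960, 540)), 0   # placeholder thumbnail and page index
d = {
    "docnm_kwd": "slides.pptx",
    "image": img,
    "page_num_int": [pn + 1],
    "top_int": [0],
    "position_int": [(pn + 1, 0, img.size[0], 0, img.size[1])],
}
print(d["position_int"])   # [(1, 0, 960, 0, 540)]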
View File
@ -27,6 +27,8 @@ from rag.utils import rmSpace
forbidden_select_fields4resume = [ forbidden_select_fields4resume = [
"name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd" "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
] ]
def remote_call(filename, binary): def remote_call(filename, binary):
q = { q = {
"header": { "header": {
@ -48,18 +50,22 @@ def remote_call(filename, binary):
} }
for _ in range(3): for _ in range(3):
try: try:
resume = requests.post("http://127.0.0.1:61670/tog", data=json.dumps(q)) resume = requests.post(
"http://127.0.0.1:61670/tog",
data=json.dumps(q))
resume = resume.json()["response"]["results"] resume = resume.json()["response"]["results"]
resume = refactor(resume) resume = refactor(resume)
for k in ["education", "work", "project", "training", "skill", "certificate", "language"]: for k in ["education", "work", "project",
if not resume.get(k) and k in resume: del resume[k] "training", "skill", "certificate", "language"]:
if not resume.get(k) and k in resume:
del resume[k]
resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x", resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
"updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}])) "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
resume = step_two.parse(resume) resume = step_two.parse(resume)
return resume return resume
except Exception as e: except Exception as e:
cron_logger.error("Resume parser error: "+str(e)) cron_logger.error("Resume parser error: " + str(e))
return {} return {}
@ -144,10 +150,13 @@ def chunk(filename, binary=None, callback=None, **kwargs):
doc["content_ltks"] = huqie.qie(doc["content_with_weight"]) doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"]) doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
for n, _ in field_map.items(): for n, _ in field_map.items():
if n not in resume:continue if n not in resume:
if isinstance(resume[n], list) and (len(resume[n]) == 1 or n not in forbidden_select_fields4resume): continue
if isinstance(resume[n], list) and (
len(resume[n]) == 1 or n not in forbidden_select_fields4resume):
resume[n] = resume[n][0] resume[n] = resume[n][0]
if n.find("_tks")>0: resume[n] = huqie.qieqie(resume[n]) if n.find("_tks") > 0:
resume[n] = huqie.qieqie(resume[n])
doc[n] = resume[n] doc[n] = resume[n]
print(doc) print(doc)
View File
@ -25,7 +25,8 @@ from deepdoc.parser import ExcelParser
class Excel(ExcelParser): class Excel(ExcelParser):
def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None): def __call__(self, fnm, binary=None, from_page=0,
to_page=10000000000, callback=None):
if not binary: if not binary:
wb = load_workbook(fnm) wb = load_workbook(fnm)
else: else:
@ -48,8 +49,10 @@ class Excel(ExcelParser):
data = [] data = []
for i, r in enumerate(rows[1:]): for i, r in enumerate(rows[1:]):
rn += 1 rn += 1
if rn-1 < from_page:continue if rn - 1 < from_page:
if rn -1>=to_page: break continue
if rn - 1 >= to_page:
break
row = [ row = [
cell.value for ii, cell.value for ii,
cell in enumerate(r) if ii not in missed] cell in enumerate(r) if ii not in missed]
@ -60,7 +63,7 @@ class Excel(ExcelParser):
done += 1 done += 1
res.append(pd.DataFrame(np.array(data), columns=headers)) res.append(pd.DataFrame(np.array(data), columns=headers))
callback(0.3, ("Extract records: {}~{}".format(from_page+1, min(to_page, from_page+rn)) + ( callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res return res
@ -73,7 +76,8 @@ def trans_datatime(s):
def trans_bool(s): def trans_bool(s):
if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$", str(s).strip(), flags=re.IGNORECASE): if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$",
str(s).strip(), flags=re.IGNORECASE):
return "yes" return "yes"
if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE): if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE):
return "no" return "no"
@ -107,13 +111,14 @@ def column_data_type(arr):
arr[i] = trans[ty](str(arr[i])) arr[i] = trans[ty](str(arr[i]))
except Exception as e: except Exception as e:
arr[i] = None arr[i] = None
#if ty == "text": # if ty == "text":
# if len(arr) > 128 and uni / len(arr) < 0.1: # if len(arr) > 128 and uni / len(arr) < 0.1:
# ty = "keyword" # ty = "keyword"
return arr, ty return arr, ty
def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=10000000000,
lang="Chinese", callback=None, **kwargs):
""" """
Excel and csv(txt) format files are supported. Excel and csv(txt) format files are supported.
For csv or txt file, the delimiter between columns is TAB. For csv or txt file, the delimiter between columns is TAB.
@ -131,7 +136,12 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
if re.search(r"\.xlsx?$", filename, re.IGNORECASE): if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
excel_parser = Excel() excel_parser = Excel()
dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback) dfs = excel_parser(
filename,
binary,
from_page=from_page,
to_page=to_page,
callback=callback)
elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE): elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = "" txt = ""
@ -149,8 +159,10 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
headers = lines[0].split(kwargs.get("delimiter", "\t")) headers = lines[0].split(kwargs.get("delimiter", "\t"))
rows = [] rows = []
for i, line in enumerate(lines[1:]): for i, line in enumerate(lines[1:]):
if i < from_page:continue if i < from_page:
if i >= to_page: break continue
if i >= to_page:
break
row = [l for l in line.split(kwargs.get("delimiter", "\t"))] row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
if len(row) != len(headers): if len(row) != len(headers):
fails.append(str(i)) fails.append(str(i))
@ -181,7 +193,13 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
del df[n] del df[n]
clmns = df.columns.values clmns = df.columns.values
txts = list(copy.deepcopy(clmns)) txts = list(copy.deepcopy(clmns))
    py_clmns = [PY.get_pinyins(re.sub(r"(/.*|（[^（）]+?）|\([^()]+?\))", "", n), '_')[0] for n in clmns]    py_clmns = [
PY.get_pinyins(
re.sub(
r"(/.*|[^]+?|\([^()]+?\))",
"",
n),
'_')[0] for n in clmns]
clmn_tys = [] clmn_tys = []
for j in range(len(clmns)): for j in range(len(clmns)):
cln, ty = column_data_type(df[clmns[j]]) cln, ty = column_data_type(df[clmns[j]])
@ -192,7 +210,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], clmns[i].replace("_", " ")) clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], clmns[i].replace("_", " "))
for i in range(len(clmns))] for i in range(len(clmns))]
eng = lang.lower() == "english"#is_english(txts) eng = lang.lower() == "english" # is_english(txts)
for ii, row in df.iterrows(): for ii, row in df.iterrows():
d = { d = {
"docnm_kwd": filename, "docnm_kwd": filename,
View File
@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from zhipuai import ZhipuAI
from dashscope import Generation
from abc import ABC from abc import ABC
from openai import OpenAI from openai import OpenAI
import openai import openai
@ -34,7 +36,8 @@ class GptTurbo(Base):
self.model_name = model_name self.model_name = model_name
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system}) if system:
history.insert(0, {"role": "system", "content": system})
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
@ -46,16 +49,18 @@ class GptTurbo(Base):
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, response.usage.completion_tokens return ans, response.usage.completion_tokens
except openai.APIError as e: except openai.APIError as e:
return "**ERROR**: "+str(e), 0 return "**ERROR**: " + str(e), 0
class MoonshotChat(GptTurbo): class MoonshotChat(GptTurbo):
def __init__(self, key, model_name="moonshot-v1-8k"): def __init__(self, key, model_name="moonshot-v1-8k"):
self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",) self.client = OpenAI(
api_key=key, base_url="https://api.moonshot.cn/v1",)
self.model_name = model_name self.model_name = model_name
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system}) if system:
history.insert(0, {"role": "system", "content": system})
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
@ -67,10 +72,9 @@ class MoonshotChat(GptTurbo):
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, response.usage.completion_tokens return ans, response.usage.completion_tokens
except openai.APIError as e: except openai.APIError as e:
return "**ERROR**: "+str(e), 0 return "**ERROR**: " + str(e), 0
from dashscope import Generation
class QWenChat(Base): class QWenChat(Base):
def __init__(self, key, model_name=Generation.Models.qwen_turbo): def __init__(self, key, model_name=Generation.Models.qwen_turbo):
import dashscope import dashscope
@ -79,7 +83,8 @@ class QWenChat(Base):
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
from http import HTTPStatus from http import HTTPStatus
if system: history.insert(0, {"role": "system", "content": system}) if system:
history.insert(0, {"role": "system", "content": system})
response = Generation.call( response = Generation.call(
self.model_name, self.model_name,
messages=history, messages=history,
@ -92,20 +97,21 @@ class QWenChat(Base):
ans += response.output.choices[0]['message']['content'] ans += response.output.choices[0]['message']['content']
tk_count += response.usage.output_tokens tk_count += response.usage.output_tokens
if response.output.choices[0].get("finish_reason", "") == "length": if response.output.choices[0].get("finish_reason", "") == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, tk_count return ans, tk_count
return "**ERROR**: " + response.message, tk_count return "**ERROR**: " + response.message, tk_count
from zhipuai import ZhipuAI
class ZhipuChat(Base): class ZhipuChat(Base):
def __init__(self, key, model_name="glm-3-turbo"): def __init__(self, key, model_name="glm-3-turbo"):
self.client = ZhipuAI(api_key=key) self.client = ZhipuAI(api_key=key)
self.model_name = model_name self.model_name = model_name
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system}) if system:
history.insert(0, {"role": "system", "content": system})
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
self.model_name, self.model_name,
@ -120,6 +126,7 @@ class ZhipuChat(Base):
except Exception as e: except Exception as e:
return "**ERROR**: " + str(e), 0 return "**ERROR**: " + str(e), 0
class LocalLLM(Base): class LocalLLM(Base):
class RPCProxy: class RPCProxy:
def __init__(self, host, port): def __init__(self, host, port):
@ -129,14 +136,17 @@ class LocalLLM(Base):
def __conn(self): def __conn(self):
from multiprocessing.connection import Client from multiprocessing.connection import Client
self._connection = Client((self.host, self.port), authkey=b'infiniflow-token4kevinhu') self._connection = Client(
(self.host, self.port), authkey=b'infiniflow-token4kevinhu')
def __getattr__(self, name): def __getattr__(self, name):
import pickle import pickle
def do_rpc(*args, **kwargs): def do_rpc(*args, **kwargs):
for _ in range(3): for _ in range(3):
try: try:
self._connection.send(pickle.dumps((name, args, kwargs))) self._connection.send(
pickle.dumps((name, args, kwargs)))
return pickle.loads(self._connection.recv()) return pickle.loads(self._connection.recv())
except Exception as e: except Exception as e:
self.__conn() self.__conn()
@ -148,7 +158,8 @@ class LocalLLM(Base):
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860) self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
def chat(self, system, history, gen_conf): def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system}) if system:
history.insert(0, {"role": "system", "content": system})
try: try:
ans = self.client.chat( ans = self.client.chat(
history, history,
View File
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from zhipuai import ZhipuAI
import io import io
from abc import ABC from abc import ABC
@ -57,7 +58,7 @@ class Base(ABC):
}, },
}, },
{ {
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
}, },
], ],
@ -92,8 +93,9 @@ class QWenCV(Base):
def prompt(self, binary): def prompt(self, binary):
# stupid as hell # stupid as hell
tmp_dir = get_project_base_directory("tmp") tmp_dir = get_project_base_directory("tmp")
if not os.path.exists(tmp_dir): os.mkdir(tmp_dir) if not os.path.exists(tmp_dir):
path = os.path.join(tmp_dir, "%s.jpg"%get_uuid()) os.mkdir(tmp_dir)
path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
Image.open(io.BytesIO(binary)).save(path) Image.open(io.BytesIO(binary)).save(path)
return [ return [
{ {
@ -103,7 +105,7 @@ class QWenCV(Base):
"image": f"file://{path}" "image": f"file://{path}"
}, },
{ {
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
}, },
], ],
@ -120,9 +122,6 @@ class QWenCV(Base):
return response.message, 0 return response.message, 0
from zhipuai import ZhipuAI
class Zhipu4V(Base): class Zhipu4V(Base):
def __init__(self, key, model_name="glm-4v", lang="Chinese"): def __init__(self, key, model_name="glm-4v", lang="Chinese"):
self.client = ZhipuAI(api_key=key) self.client = ZhipuAI(api_key=key)
View File
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from zhipuai import ZhipuAI
import os import os
from abc import ABC from abc import ABC
@ -40,11 +41,11 @@ flag_model = FlagModel(model_dir,
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available()) use_fp16=torch.cuda.is_available())
class Base(ABC): class Base(ABC):
def __init__(self, key, model_name): def __init__(self, key, model_name):
pass pass
def encode(self, texts: list, batch_size=32): def encode(self, texts: list, batch_size=32):
raise NotImplementedError("Please implement encode method!") raise NotImplementedError("Please implement encode method!")
@ -67,11 +68,11 @@ class HuEmbedding(Base):
""" """
self.model = flag_model self.model = flag_model
def encode(self, texts: list, batch_size=32): def encode(self, texts: list, batch_size=32):
texts = [t[:2000] for t in texts] texts = [t[:2000] for t in texts]
token_count = 0 token_count = 0
for t in texts: token_count += num_tokens_from_string(t) for t in texts:
token_count += num_tokens_from_string(t)
res = [] res = []
for i in range(0, len(texts), batch_size): for i in range(0, len(texts), batch_size):
res.extend(self.model.encode(texts[i:i + batch_size]).tolist()) res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
@ -90,7 +91,8 @@ class OpenAIEmbed(Base):
def encode(self, texts: list, batch_size=32): def encode(self, texts: list, batch_size=32):
res = self.client.embeddings.create(input=texts, res = self.client.embeddings.create(input=texts,
model=self.model_name) model=self.model_name)
return np.array([d.embedding for d in res.data]), res.usage.total_tokens return np.array([d.embedding for d in res.data]
), res.usage.total_tokens
def encode_queries(self, text): def encode_queries(self, text):
res = self.client.embeddings.create(input=[text], res = self.client.embeddings.create(input=[text],
@ -111,7 +113,7 @@ class QWenEmbed(Base):
for i in range(0, len(texts), batch_size): for i in range(0, len(texts), batch_size):
resp = dashscope.TextEmbedding.call( resp = dashscope.TextEmbedding.call(
model=self.model_name, model=self.model_name,
input=texts[i:i+batch_size], input=texts[i:i + batch_size],
text_type="document" text_type="document"
) )
embds = [[] for _ in range(len(resp["output"]["embeddings"]))] embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
@ -127,10 +129,10 @@ class QWenEmbed(Base):
input=text[:2048], input=text[:2048],
text_type="query" text_type="query"
) )
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["total_tokens"] return np.array(resp["output"]["embeddings"][0]
["embedding"]), resp["usage"]["total_tokens"]
from zhipuai import ZhipuAI
class ZhipuEmbed(Base): class ZhipuEmbed(Base):
def __init__(self, key, model_name="embedding-2"): def __init__(self, key, model_name="embedding-2"):
self.client = ZhipuAI(api_key=key) self.client = ZhipuAI(api_key=key)
@ -139,7 +141,8 @@ class ZhipuEmbed(Base):
def encode(self, texts: list, batch_size=32): def encode(self, texts: list, batch_size=32):
res = self.client.embeddings.create(input=texts, res = self.client.embeddings.create(input=texts,
model=self.model_name) model=self.model_name)
return np.array([d.embedding for d in res.data]), res.usage.total_tokens return np.array([d.embedding for d in res.data]
), res.usage.total_tokens
def encode_queries(self, text): def encode_queries(self, text):
res = self.client.embeddings.create(input=text, res = self.client.embeddings.create(input=text,
View File
@ -9,7 +9,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
class RPCHandler: class RPCHandler:
def __init__(self): def __init__(self):
self._functions = { } self._functions = {}
def register_function(self, func): def register_function(self, func):
self._functions[func.__name__] = func self._functions[func.__name__] = func
@ -21,7 +21,7 @@ class RPCHandler:
func_name, args, kwargs = pickle.loads(connection.recv()) func_name, args, kwargs = pickle.loads(connection.recv())
# Run the RPC and send a response # Run the RPC and send a response
try: try:
r = self._functions[func_name](*args,**kwargs) r = self._functions[func_name](*args, **kwargs)
connection.send(pickle.dumps(r)) connection.send(pickle.dumps(r))
except Exception as e: except Exception as e:
connection.send(pickle.dumps(e)) connection.send(pickle.dumps(e))
@ -44,11 +44,17 @@ def rpc_server(hdlr, address, authkey):
models = [] models = []
tokenizer = None tokenizer = None
def chat(messages, gen_conf): def chat(messages, gen_conf):
global tokenizer global tokenizer
model = Model() model = Model()
try: try:
conf = {"max_new_tokens": int(gen_conf.get("max_tokens", 256)), "temperature": float(gen_conf.get("temperature", 0.1))} conf = {
"max_new_tokens": int(
gen_conf.get(
"max_tokens", 256)), "temperature": float(
gen_conf.get(
"temperature", 0.1))}
print(messages, conf) print(messages, conf)
text = tokenizer.apply_chat_template( text = tokenizer.apply_chat_template(
messages, messages,
@ -65,7 +71,8 @@ def chat(messages, gen_conf):
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
] ]
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] return tokenizer.batch_decode(
generated_ids, skip_special_tokens=True)[0]
except Exception as e: except Exception as e:
return str(e) return str(e)
@ -75,10 +82,15 @@ def Model():
random.seed(time.time()) random.seed(time.time())
return random.choice(models) return random.choice(models)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, help="Model name") parser.add_argument("--model_name", type=str, help="Model name")
parser.add_argument("--port", default=7860, type=int, help="RPC serving port") parser.add_argument(
"--port",
default=7860,
type=int,
help="RPC serving port")
args = parser.parse_args() args = parser.parse_args()
handler = RPCHandler() handler = RPCHandler()
@ -93,4 +105,5 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(args.model_name) tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Run the server # Run the server
rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu') rpc_server(handler, ('0.0.0.0', args.port),
authkey=b'infiniflow-token4kevinhu')
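The server above pairs with the LocalLLM.RPCProxy client in the chat-model file earlier in this diff: both sides exchange pickled (function_name, args, kwargs) tuples over a multiprocessing connection protected by the same authkey. A minimal hand-rolled client, with host, port and the prompt as placeholders, and assuming chat() has been registered on the handler as shown:

import pickle
from multiprocessing.connection import Client

conn = Client(("127.0.0.1", 7860), authkey=b"infiniflow-token4kevinhu")
conn.send(pickle.dumps(("chat",
                        ([{"role": "user", "content": "hi"}],
                         {"max_tokens": 64, "temperature": 0.1}),
                        {})))
print(pickle.loads(conn.recv()))   # the decoded answer, or a pickled exception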
View File
@ -372,7 +372,8 @@ class PptChunker(HuChunker):
tb = shape.table tb = shape.table
rows = [] rows = []
for i in range(1, len(tb.rows)): for i in range(1, len(tb.rows)):
rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) rows.append("; ".join([tb.cell(
0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
return "\n".join(rows) return "\n".join(rows)
if shape.has_text_frame: if shape.has_text_frame:
@ -382,7 +383,8 @@ class PptChunker(HuChunker):
texts = [] texts = []
for p in shape.shapes: for p in shape.shapes:
t = self.__extract(p) t = self.__extract(p)
if t: texts.append(t) if t:
texts.append(t)
return "\n".join(texts) return "\n".join(texts)
def __call__(self, fnm): def __call__(self, fnm):
@ -395,7 +397,8 @@ class PptChunker(HuChunker):
texts = [] texts = []
for shape in slide.shapes: for shape in slide.shapes:
txt = self.__extract(shape) txt = self.__extract(shape)
if txt: texts.append(txt) if txt:
texts.append(txt)
txts.append("\n".join(texts)) txts.append("\n".join(texts))
import aspose.slides as slides import aspose.slides as slides
@ -404,9 +407,12 @@ class PptChunker(HuChunker):
with slides.Presentation(BytesIO(fnm)) as presentation: with slides.Presentation(BytesIO(fnm)) as presentation:
for slide in presentation.slides: for slide in presentation.slides:
buffered = BytesIO() buffered = BytesIO()
slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg) slide.get_thumbnail(
0.5, 0.5).save(
buffered, drawing.imaging.ImageFormat.jpeg)
imgs.append(buffered.getvalue()) imgs.append(buffered.getvalue())
assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) assert len(imgs) == len(
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
flds = self.Fields() flds = self.Fields()
flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))] flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
@ -445,7 +451,8 @@ class TextChunker(HuChunker):
if isinstance(fnm, str): if isinstance(fnm, str):
with open(fnm, "r") as f: with open(fnm, "r") as f:
txt = f.read() txt = f.read()
else: txt = fnm.decode("utf-8") else:
txt = fnm.decode("utf-8")
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
flds.table_chunks = [] flds.table_chunks = []
return flds return flds
View File
@ -149,7 +149,8 @@ class EsQueryer:
atks = toDict(atks) atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss] btkss = [toDict(tks) for tks in btkss]
tksim = [self.similarity(atks, btks) for btks in btkss] tksim = [self.similarity(atks, btks) for btks in btkss]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0] return np.array(sims[0]) * vtweight + \
np.array(tksim) * tkweight, tksim, sims[0]
def similarity(self, qtwt, dtwt): def similarity(self, qtwt, dtwt):
if isinstance(dtwt, type("")): if isinstance(dtwt, type("")):
@ -159,11 +160,11 @@ class EsQueryer:
s = 1e-9 s = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
if k in dtwt: if k in dtwt:
s += v# * dtwt[k] s += v # * dtwt[k]
q = 1e-9 q = 1e-9
for k, v in qtwt.items(): for k, v in qtwt.items():
q += v #* v q += v # * v
#d = 1e-9 #d = 1e-9
#for k, v in dtwt.items(): # for k, v in dtwt.items():
# d += v * v # d += v * v
return s / q #math.sqrt(q) / math.sqrt(d) return s / q # math.sqrt(q) / math.sqrt(d)
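A toy walk-through of the term-weight similarity above: the score is the weight mass of query terms that also occur in the chunk, divided by the total query weight, and hybrid_similarity then mixes it with the vector cosine using tkweight and vtweight. The weights below are made up:

qtwt = {"retrieval": 0.6, "augmented": 0.3, "generation": 0.1}   # query term weights
dtwt = {"retrieval": 0.5, "generation": 0.2}                     # chunk term weights
s = 1e-9 + sum(v for k, v in qtwt.items() if k in dtwt)   # ~0.7
q = 1e-9 + sum(qtwt.values())                             # ~1.0
print(round(s / q, 3))   # 0.7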
View File
@ -80,14 +80,18 @@ class Dealer:
if not req.get("sort"): if not req.get("sort"):
s = s.sort( s = s.sort(
{"create_time": {"order": "desc", "unmapped_type": "date"}}, {"create_time": {"order": "desc", "unmapped_type": "date"}},
{"create_timestamp_flt": {"order": "desc", "unmapped_type": "float"}} {"create_timestamp_flt": {
"order": "desc", "unmapped_type": "float"}}
) )
else: else:
s = s.sort( s = s.sort(
{"page_num_int": {"order": "asc", "unmapped_type": "float", "mode": "avg", "numeric_type": "double"}}, {"page_num_int": {"order": "asc", "unmapped_type": "float",
{"top_int": {"order": "asc", "unmapped_type": "float", "mode": "avg", "numeric_type": "double"}}, "mode": "avg", "numeric_type": "double"}},
{"top_int": {"order": "asc", "unmapped_type": "float",
"mode": "avg", "numeric_type": "double"}},
{"create_time": {"order": "desc", "unmapped_type": "date"}}, {"create_time": {"order": "desc", "unmapped_type": "date"}},
{"create_timestamp_flt": {"order": "desc", "unmapped_type": "float"}} {"create_timestamp_flt": {
"order": "desc", "unmapped_type": "float"}}
) )
if qst: if qst:
@ -180,11 +184,13 @@ class Dealer:
m = {n: d.get(n) for n in flds if d.get(n) is not None} m = {n: d.get(n) for n in flds if d.get(n) is not None}
for n, v in m.items(): for n, v in m.items():
if isinstance(v, type([])): if isinstance(v, type([])):
m[n] = "\t".join([str(vv) if not isinstance(vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v]) m[n] = "\t".join([str(vv) if not isinstance(
vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v])
continue continue
if not isinstance(v, type("")): if not isinstance(v, type("")):
m[n] = str(m[n]) m[n] = str(m[n])
if n.find("tks")>0: m[n] = rmSpace(m[n]) if n.find("tks") > 0:
m[n] = rmSpace(m[n])
if m: if m:
res[d["id"]] = m res[d["id"]] = m
@ -205,12 +211,16 @@ class Dealer:
if pieces[i] == "```": if pieces[i] == "```":
st = i st = i
i += 1 i += 1
while i<len(pieces) and pieces[i] != "```": while i < len(pieces) and pieces[i] != "```":
i += 1 i += 1
if i < len(pieces): i += 1 if i < len(pieces):
pieces_.append("".join(pieces[st: i])+"\n") i += 1
pieces_.append("".join(pieces[st: i]) + "\n")
else: else:
pieces_.extend(re.split(r"([^\|][;。?!\n]|[a-z][.?;!][ \n])", pieces[i])) pieces_.extend(
re.split(
r"([^\|][;。?!\n]|[a-z][.?;!][ \n])",
pieces[i]))
i += 1 i += 1
pieces = pieces_ pieces = pieces_
else: else:
@ -234,7 +244,8 @@ class Dealer:
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0])) len(ans_v[0]), len(chunk_v[0]))
chunks_tks = [huqie.qie(self.qryr.rmWWW(ck)).split(" ") for ck in chunks] chunks_tks = [huqie.qie(self.qryr.rmWWW(ck)).split(" ")
for ck in chunks]
cites = {} cites = {}
for i, a in enumerate(pieces_): for i, a in enumerate(pieces_):
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
@ -258,9 +269,11 @@ class Dealer:
continue continue
if i not in cites: if i not in cites:
continue continue
for c in cites[i]: assert int(c) < len(chunk_v)
for c in cites[i]: for c in cites[i]:
if c in seted:continue assert int(c) < len(chunk_v)
for c in cites[i]:
if c in seted:
continue
res += f" ##{c}$$" res += f" ##{c}$$"
seted.add(c) seted.add(c)
@ -343,7 +356,11 @@ class Dealer:
if dnm not in ranks["doc_aggs"]: if dnm not in ranks["doc_aggs"]:
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
ranks["doc_aggs"][dnm]["count"] += 1 ranks["doc_aggs"][dnm]["count"] += 1
ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)] ranks["doc_aggs"] = [{"doc_name": k,
"doc_id": v["doc_id"],
"count": v["count"]} for k,
v in sorted(ranks["doc_aggs"].items(),
key=lambda x:x[1]["count"] * -1)]
return ranks return ranks
@ -354,10 +371,17 @@ class Dealer:
replaces = [] replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql): for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3) fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v))) match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match)) fld, huqie.qieqie(huqie.qie(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces: sql = sql.replace(p, r, 1) for p, r in replaces:
sql = sql.replace(p, r, 1)
chat_logger.info(f"To es: {sql}") chat_logger.info(f"To es: {sql}")
try: try:
@ -366,4 +390,3 @@ class Dealer:
except Exception as e: except Exception as e:
chat_logger.error(f"SQL failure: {sql} =>" + str(e)) chat_logger.error(f"SQL failure: {sql} =>" + str(e))
return {"error": str(e)} return {"error": str(e)}
View File
@ -150,8 +150,10 @@ class Dealer:
return 6 return 6
def ner(t): def ner(t):
if re.match(r"[0-9,.]{2,}$", t): return 2 if re.match(r"[0-9,.]{2,}$", t):
if re.match(r"[a-z]{1,2}$", t): return 0.01 return 2
if re.match(r"[a-z]{1,2}$", t):
return 0.01
if not self.ne or t not in self.ne: if not self.ne or t not in self.ne:
return 1 return 1
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
View File
@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import os import os
from api.utils import get_base_config,decrypt_database_config from api.utils import get_base_config, decrypt_database_config
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger from api.utils.log_utils import LoggerFactory, getLogger
@ -28,7 +28,11 @@ MINIO = decrypt_database_config(name="minio")
DOC_MAXIMUM_SIZE = 128 * 1024 * 1024 DOC_MAXIMUM_SIZE = 128 * 1024 * 1024
# Logger # Logger
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag")) LoggerFactory.set_directory(
os.path.join(
get_project_base_directory(),
"logs",
"rag"))
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 10 LoggerFactory.LEVEL = 10
@ -37,4 +41,3 @@ minio_logger = getLogger("minio")
cron_logger = getLogger("cron_logger") cron_logger = getLogger("cron_logger")
chunk_logger = getLogger("chunk_logger") chunk_logger = getLogger("chunk_logger")
database_logger = getLogger("database") database_logger = getLogger("database")
View File
@ -47,7 +47,7 @@ def collect(tm):
def set_dispatching(docid): def set_dispatching(docid):
try: try:
DocumentService.update_by_id( DocumentService.update_by_id(
docid, {"progress": random.random()*1 / 100., docid, {"progress": random.random() * 1 / 100.,
"progress_msg": "Task dispatched...", "progress_msg": "Task dispatched...",
"process_begin_at": get_format_time() "process_begin_at": get_format_time()
}) })
@ -56,7 +56,10 @@ def set_dispatching(docid):
def dispatch(): def dispatch():
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"broker.tm") tm_fnm = os.path.join(
get_project_base_directory(),
"rag/res",
f"broker.tm")
tm = findMaxTm(tm_fnm) tm = findMaxTm(tm_fnm)
rows = collect(tm) rows = collect(tm)
if len(rows) == 0: if len(rows) == 0:
@ -82,17 +85,22 @@ def dispatch():
tsks = [] tsks = []
if r["type"] == FileType.PDF.value: if r["type"] == FileType.PDF.value:
do_layout = r["parser_config"].get("layout_recognize", True) do_layout = r["parser_config"].get("layout_recognize", True)
pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"])) pages = PdfParser.total_page_number(
r["name"], MINIO.get(r["kb_id"], r["location"]))
page_size = r["parser_config"].get("task_page_size", 12) page_size = r["parser_config"].get("task_page_size", 12)
if r["parser_id"] == "paper": page_size = r["parser_config"].get("task_page_size", 22) if r["parser_id"] == "paper":
if r["parser_id"] == "one": page_size = 1000000000 page_size = r["parser_config"].get("task_page_size", 22)
if not do_layout: page_size = 1000000000 if r["parser_id"] == "one":
page_size = 1000000000
if not do_layout:
page_size = 1000000000
page_ranges = r["parser_config"].get("pages") page_ranges = r["parser_config"].get("pages")
if not page_ranges: page_ranges = [(1, 100000)] if not page_ranges:
for s,e in page_ranges: page_ranges = [(1, 100000)]
for s, e in page_ranges:
s -= 1 s -= 1
s = max(0, s) s = max(0, s)
e = min(e-1, pages) e = min(e - 1, pages)
for p in range(s, e, page_size): for p in range(s, e, page_size):
task = new_task() task = new_task()
task["from_page"] = p task["from_page"] = p
@ -100,7 +108,9 @@ def dispatch():
tsks.append(task) tsks.append(task)
elif r["parser_id"] == "table": elif r["parser_id"] == "table":
rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"])) rn = HuExcelParser.row_number(
r["name"], MINIO.get(
r["kb_id"], r["location"]))
for i in range(0, rn, 3000): for i in range(0, rn, 3000):
task = new_task() task = new_task()
task["from_page"] = i task["from_page"] = i
@ -120,27 +130,37 @@ def update_progress():
for d in docs: for d in docs:
try: try:
tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time) tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time)
if not tsks:continue if not tsks:
continue
msg = [] msg = []
prg = 0 prg = 0
finished = True finished = True
bad = 0 bad = 0
status = TaskStatus.RUNNING.value status = TaskStatus.RUNNING.value
for t in tsks: for t in tsks:
if 0 <= t.progress < 1: finished = False if 0 <= t.progress < 1:
finished = False
prg += t.progress if t.progress >= 0 else 0 prg += t.progress if t.progress >= 0 else 0
msg.append(t.progress_msg) msg.append(t.progress_msg)
if t.progress == -1: bad += 1 if t.progress == -1:
bad += 1
prg /= len(tsks) prg /= len(tsks)
if finished and bad: if finished and bad:
prg = -1 prg = -1
status = TaskStatus.FAIL.value status = TaskStatus.FAIL.value
elif finished: status = TaskStatus.DONE.value elif finished:
status = TaskStatus.DONE.value
msg = "\n".join(msg) msg = "\n".join(msg)
info = {"process_duation": datetime.timestamp(datetime.now())-d["process_begin_at"].timestamp(), "run": status} info = {
if prg !=0 : info["progress"] = prg "process_duation": datetime.timestamp(
if msg: info["progress_msg"] = msg datetime.now()) -
d["process_begin_at"].timestamp(),
"run": status}
if prg != 0:
info["progress"] = prg
if msg:
info["progress_msg"] = msg
DocumentService.update_by_id(d["id"], info) DocumentService.update_by_id(d["id"], info)
except Exception as e: except Exception as e:
cron_logger.error("fetch task exception:" + str(e)) cron_logger.error("fetch task exception:" + str(e))
View File
@ -67,7 +67,7 @@ FACTORY = {
def set_progress(task_id, from_page=0, to_page=-1, def set_progress(task_id, from_page=0, to_page=-1,
prog=None, msg="Processing..."): prog=None, msg="Processing..."):
if prog is not None and prog < 0: if prog is not None and prog < 0:
msg = "[ERROR]"+msg msg = "[ERROR]" + msg
cancel = TaskService.do_cancel(task_id) cancel = TaskService.do_cancel(task_id)
if cancel: if cancel:
msg += " [Canceled]" msg += " [Canceled]"
@ -188,11 +188,13 @@ def embedding(docs, mdl, parser_config={}, callback=None):
cnts_ = np.array([]) cnts_ = np.array([])
for i in range(0, len(cnts), batch_size): for i in range(0, len(cnts), batch_size):
vts, c = mdl.encode(cnts[i: i+batch_size]) vts, c = mdl.encode(cnts[i: i + batch_size])
if len(cnts_) == 0: cnts_ = vts if len(cnts_) == 0:
else: cnts_ = np.concatenate((cnts_, vts), axis=0) cnts_ = vts
else:
cnts_ = np.concatenate((cnts_, vts), axis=0)
tk_count += c tk_count += c
callback(prog=0.7+0.2*(i+1)/len(cnts), msg="") callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="")
cnts = cnts_ cnts = cnts_
title_w = float(parser_config.get("filename_embd_weight", 0.1)) title_w = float(parser_config.get("filename_embd_weight", 0.1))
@ -234,7 +236,9 @@ def main(comm, mod):
continue continue
# TODO: exception handler # TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ") ## set_progress(r["did"], -1, "ERROR: ")
callback(msg="Finished slicing files(%d). Start to embedding the content."%len(cks)) callback(
msg="Finished slicing files(%d). Start to embedding the content." %
len(cks))
try: try:
tk_count = embedding(cks, embd_mdl, r["parser_config"], callback) tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
except Exception as e: except Exception as e: