Add resume parser and fix bugs (#59)

* Update .gitignore

* Update .gitignore

* Add resume parser and fix bugs
KevinHuSh 2024-02-07 19:27:23 +08:00 committed by GitHub
parent eb8254e688
commit c5ea37cd30
16 changed files with 451 additions and 57 deletions

.gitignore vendored (4 changes)
View File

@@ -3,6 +3,10 @@
 debug/
 target/
 __pycache__/
+hudet/
+cv/
+layout_app.py
+resume/
 # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
 # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html

View File

@@ -47,17 +47,20 @@ def list():
         tenant_id = DocumentService.get_tenant_id(req["doc_id"])
         if not tenant_id:
             return get_data_error_result(retmsg="Tenant not found!")
+        e, doc = DocumentService.get_by_id(doc_id)
+        if not e:
+            return get_data_error_result(retmsg="Document not found!")
         query = {
             "doc_ids": [doc_id], "page": page, "size": size, "question": question
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
         sres = retrievaler.search(query, search.index_name(tenant_id))
-        res = {"total": sres.total, "chunks": []}
+        res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
                 "chunk_id": id,
-                "content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"],
+                "content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id].get("content_with_weight", ""),
                 "doc_id": sres.field[id]["doc_id"],
                 "docnm_kwd": sres.field[id]["docnm_kwd"],
                 "important_kwd": sres.field[id].get("important_kwd", []),
@@ -110,7 +113,7 @@ def get():
           "important_kwd")
 def set():
     req = request.json
-    d = {"id": req["chunk_id"]}
+    d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
     d["content_ltks"] = huqie.qie(req["content_with_weight"])
     d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
     d["important_kwd"] = req["important_kwd"]
@@ -181,11 +184,12 @@ def create():
     md5 = hashlib.md5()
     md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
     chunck_id = md5.hexdigest()
-    d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])}
+    d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]}
     d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
     d["important_kwd"] = req.get("important_kwd", [])
     d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
     d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
+    d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])

View File

@@ -13,16 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import re
 from flask import request
 from flask_login import login_required
 from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
-from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.llm_service import LLMService, LLMBundle
+from api.settings import access_logger
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 from rag.llm import ChatModel
 from rag.nlp import retrievaler
+from rag.nlp.search import index_name
 from rag.utils import num_tokens_from_string, encoder
@@ -163,6 +168,17 @@ def chat(dialog, messages, **kwargs):
     if not llm:
         raise LookupError("LLM(%s) not found"%dialog.llm_id)
     llm = llm[0]
+    question = messages[-1]["content"]
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
+    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
+    ## try to use sql if field mapping is good to go
+    if field_map:
+        markdown_tbl, chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
+        if markdown_tbl:
+            return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
     prompt_config = dialog.prompt_config
     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":continue
@@ -170,9 +186,6 @@ def chat(dialog, messages, **kwargs):
         if p["key"] not in kwargs:
             prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
-    question = messages[-1]["content"]
-    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
-    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
     kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
                                     dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
@@ -196,4 +209,46 @@ def chat(dialog, messages, **kwargs):
                              vtweight=dialog.vector_similarity_weight)
     for c in kbinfos["chunks"]:
         if c.get("vector"):del c["vector"]
     return {"answer": answer, "retrieval": kbinfos}
+
+
+def use_sql(question, field_map, tenant_id, chat_mdl):
+    sys_prompt = "你是一个DBA。你需要这对以下表的字段结构根据我的问题写出sql。"
+    user_promt = """
+表名{}
+数据库表字段说明如下
+{}
+问题{}
+请写出SQL
+""".format(
+        index_name(tenant_id),
+        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
+        question
+    )
+    sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
+    sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
+    sql = re.sub(r" +", " ", sql)
+    if sql[:len("select ")].lower() != "select ":
+        return None, None
+    if sql[:len("select *")].lower() != "select *":
+        sql = "select doc_id,docnm_kwd," + sql[6:]
+
+    tbl = retrievaler.sql_retrieval(sql)
+    if not tbl: return None, None
+
+    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
+    docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
+    clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
+
+    clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
+    line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
+    rows = ["|".join([str(r[i]) for i in clmn_idx]) + "|" for r in tbl["rows"]]
+    if not docid_idx or not docnm_idx:
+        access_logger.error("SQL missing field: " + sql)
+        return "\n".join([clmns, line, "\n".join(rows)]), []
+
+    rows = "\n".join([r + f"##{ii}$$" for ii, r in enumerate(rows)])
+    docid_idx = list(docid_idx)[0]
+    docnm_idx = list(docnm_idx)[0]
+    return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]

View File

@@ -21,9 +21,6 @@ import flask
 from elasticsearch_dsl import Q
 from flask import request
 from flask_login import login_required, current_user
-from api.db.db_models import Task
-from api.db.services.task_service import TaskService
 from rag.nlp import search
 from rag.utils import ELASTICSEARCH
 from api.db.services import duplicate_name
@@ -35,7 +32,7 @@ from api.db.services.document_service import DocumentService
 from api.settings import RetCode
 from api.utils.api_utils import get_json_result
 from rag.utils.minio_conn import MINIO
-from api.utils.file_utils import filename_type
+from api.utils.file_utils import filename_type, thumbnail


 @manager.route('/upload', methods=['POST'])
@@ -78,7 +75,8 @@ def upload():
             "type": filename_type(filename),
             "name": filename,
             "location": location,
-            "size": len(blob)
+            "size": len(blob),
+            "thumbnail": thumbnail(filename, blob)
         })
         return get_json_result(data=doc.to_json())
     except Exception as e:

View File

@@ -474,7 +474,7 @@ class Knowledgebase(DataBaseModel):
     vector_similarity_weight = FloatField(default=0.3)

     parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
-    parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
+    parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1")

     def __str__(self):
@@ -489,7 +489,7 @@ class Document(DataBaseModel):
     thumbnail = TextField(null=True, help_text="thumbnail base64 string")
     kb_id = CharField(max_length=256, null=False, index=True)
     parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
-    parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
+    parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
     source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
     type = CharField(max_length=32, null=False, help_text="file extension")
     created_by = CharField(max_length=32, null=False, help_text="who created it")

View File

@@ -21,5 +21,6 @@ class DialogService(CommonService):
     model = Dialog

 class ConversationService(CommonService):
     model = Conversation

View File

@@ -63,3 +63,31 @@ class KnowledgebaseService(CommonService):
         d = kbs[0].to_dict()
         d["embd_id"] = kbs[0].tenant.embd_id
         return d
+
+    @classmethod
+    @DB.connection_context()
+    def update_parser_config(cls, id, config):
+        e, m = cls.get_by_id(id)
+        if not e:raise LookupError(f"knowledgebase({id}) not found.")
+
+        def dfs_update(old, new):
+            for k, v in new.items():
+                if k not in old:
+                    old[k] = v
+                    continue
+                if isinstance(v, dict):
+                    assert isinstance(old[k], dict)
+                    dfs_update(old[k], v)
+                else: old[k] = v
+
+        dfs_update(m.parser_config, config)
+        cls.update_by_id(id, m.parser_config)
+
+    @classmethod
+    @DB.connection_context()
+    def get_field_map(cls, ids):
+        conf = {}
+        for k in cls.get_by_ids(ids):
+            if k.parser_config and "field_map" in k.parser_config:
+                conf.update(k.parser_config)
+        return conf
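
A small illustration of the merge semantics of dfs_update above: nested dicts are merged key by key, everything else is overwritten (the config values here are made up):

old = {"pages": [[0, 1000000]], "field_map": {"name_kwd": "姓名/名字"}}
new = {"field_map": {"age_int": "年龄/岁/年纪"}, "pages": [[0, 12]]}
# After dfs_update(old, new):
#   old == {"pages": [[0, 12]],
#           "field_map": {"name_kwd": "姓名/名字", "age_int": "年龄/岁/年纪"}}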

View File

@@ -13,11 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import base64
 import json
 import os
 import re
+from io import BytesIO
+
+import fitz
+from PIL import Image
 from cachetools import LRUCache, cached
 from ruamel.yaml import YAML
@@ -150,4 +153,33 @@ def filename_type(filename):
         return FileType.AURAL.value

     if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
         return FileType.VISUAL
+
+
+def thumbnail(filename, blob):
+    filename = filename.lower()
+    if re.match(r".*\.pdf$", filename):
+        pdf = fitz.open(stream=blob, filetype="pdf")
+        pix = pdf[0].get_pixmap(matrix=fitz.Matrix(0.03, 0.03))
+        buffered = BytesIO()
+        Image.frombytes("RGB", [pix.width, pix.height],
+                        pix.samples).save(buffered, format="png")
+        return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
+
+    if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
+        return ("data:image/%s;base64,"%filename.split(".")[-1]) + base64.b64encode(Image.open(BytesIO(blob)).thumbnail((30, 30)).tobytes())
+
+    if re.match(r".*\.(ppt|pptx)$", filename):
+        import aspose.slides as slides
+        import aspose.pydrawing as drawing
+        try:
+            with slides.Presentation(BytesIO(blob)) as presentation:
+                buffered = BytesIO()
+                presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
+                return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
+        except Exception as e:
+            pass
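
A hedged aside on the image branch above: PIL's Image.thumbnail() resizes in place and returns None, and base64.b64encode() returns bytes rather than str, so a standalone variant of that branch would look roughly like this (a sketch, not the committed code):

import base64
from io import BytesIO
from PIL import Image

def image_thumbnail_b64(blob, suffix="png"):
    img = Image.open(BytesIO(blob))
    img.thumbnail((30, 30))              # in-place resize; returns None
    buf = BytesIO()
    img.save(buf, format="PNG")          # re-encode the shrunken image
    return "data:image/%s;base64," % suffix + base64.b64encode(buf.getvalue()).decode("utf-8")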

View File

@@ -3,7 +3,6 @@ import re
 from collections import Counter
 from api.db import ParserType
-from rag.cv.ppdetection import PPDet
 from rag.parser import tokenize
 from rag.nlp import huqie
 from rag.parser.pdf_parser import HuParser

rag/app/resume.py (new file, 102 lines)
View File

@@ -0,0 +1,102 @@
import copy
import json
import os
import re
import requests
from api.db.services.knowledgebase_service import KnowledgebaseService
from rag.nlp import huqie
from rag.settings import cron_logger
from rag.utils import rmSpace
def chunk(filename, binary=None, callback=None, **kwargs):
    if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE): raise NotImplementedError("file type not supported yet(pdf supported)")

    url = os.environ.get("INFINIFLOW_SERVER")
    if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
    token = os.environ.get("INFINIFLOW_TOKEN")
    if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")

    if not binary:
        with open(filename, "rb") as f: binary = f.read()

    def remote_call():
        nonlocal filename, binary
        for _ in range(3):
            try:
                res = requests.post(url + "/v1/layout/resume/", files=[(filename, binary)],
                                    headers={"Authorization": token}, timeout=180)
                res = res.json()
                if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
                return res["data"]
            except RuntimeError as e:
                raise e
            except Exception as e:
                cron_logger.error("resume parsing:" + str(e))

    resume = remote_call()
    print(json.dumps(resume, ensure_ascii=False, indent=2))
    field_map = {
        "name_kwd": "姓名/名字",
        "gender_kwd": "性别(男,女)",
        "age_int": "年龄/岁/年纪",
        "phone_kwd": "电话/手机/微信",
        "email_tks": "email/e-mail/邮箱",
        "position_name_tks": "职位/职能/岗位/职责",
        "expect_position_name_tks": "期望职位/期望职能/期望岗位",
        "hightest_degree_kwd": "最高学历（高中，职高，硕士，本科，博士，初中，中技，中专，专科，专升本，MPA，MBA，EMBA）",
        "first_degree_kwd": "第一学历（高中，职高，硕士，本科，博士，初中，中技，中专，专科，专升本，MPA，MBA，EMBA）",
        "first_major_tks": "第一学历专业",
        "first_school_name_tks": "第一学历毕业学校",
        "edu_first_fea_kwd": "第一学历标签（211，留学，双一流，985，海外知名，重点大学，中专，专升本，专科，本科，大专）",
        "degree_kwd": "过往学历（高中，职高，硕士，本科，博士，初中，中技，中专，专科，专升本，MPA，MBA，EMBA）",
        "major_tks": "学过的专业/过往专业",
        "school_name_tks": "学校/毕业院校",
        "sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)",
        "edu_fea_kwd": "教育标签（211，留学，双一流，985，海外知名，重点大学，中专，专升本，专科，本科，大专）",
        "work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年",
        "birth_dt": "生日/出生年份",
        "corp_nm_tks": "就职过的公司/之前的公司/上过班的公司",
        "corporation_name_tks": "最近就职(上班)的公司/上一家公司",
        "edu_end_int": "毕业年份",
        "expect_city_names_tks": "期望城市",
        "industry_name_tks": "所在行业"
    }
    titles = []
    for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
        v = resume.get(n, "")
        if isinstance(v, list): v = v[0]
        if n.find("tks") > 0: v = rmSpace(v)
        titles.append(str(v))

    doc = {
        "docnm_kwd": filename,
        "title_tks": huqie.qie("-".join(titles) + "-简历")
    }
    doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])

    pairs = []
    for n, m in field_map.items():
        if not resume.get(n): continue
        v = resume[n]
        if isinstance(v, list): v = " ".join(v)
        if n.find("tks") > 0: v = rmSpace(v)
        pairs.append((m, str(v)))

    doc["content_with_weight"] = "\n".join(["{}: {}".format(re.sub(r"（[^（）]+）", "", k), v) for k, v in pairs])
    doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
    doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
    for n, _ in field_map.items(): doc[n] = resume[n]
    print(doc)
    KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map})
    return [doc]


if __name__ == "__main__":
    import sys

    def dummy(a, b):
        pass

    chunk(sys.argv[1], callback=dummy)
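
A minimal way to exercise this parser locally, assuming the remote layout service is reachable; the server address, token, file name and kb_id below are placeholders:

import os
os.environ.setdefault("INFINIFLOW_SERVER", "http://layout-service:8000")   # placeholder
os.environ.setdefault("INFINIFLOW_TOKEN", "<token>")                       # placeholder

from rag.app import resume
docs = resume.chunk("sample_resume.pdf", callback=lambda prog, msg: None, kb_id="<kb-id>")
print(docs[0]["content_with_weight"])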

View File

@@ -1,13 +1,13 @@
 import copy
-import random
 import re
 from io import BytesIO
 from xpinyin import Pinyin
 import numpy as np
 import pandas as pd
-from nltk import word_tokenize
 from openpyxl import load_workbook
 from dateutil.parser import parse as datetime_parse
+from api.db.services.knowledgebase_service import KnowledgebaseService
 from rag.parser import is_english, tokenize
 from rag.nlp import huqie, stemmer
@@ -27,18 +27,19 @@ class Excel(object):
             ws = wb[sheetname]
             rows = list(ws.rows)
             headers = [cell.value for cell in rows[0]]
-            missed = set([i for i,h in enumerate(headers) if h is None])
-            headers = [cell.value for i,cell in enumerate(rows[0]) if i not in missed]
+            missed = set([i for i, h in enumerate(headers) if h is None])
+            headers = [cell.value for i, cell in enumerate(rows[0]) if i not in missed]
             data = []
             for i, r in enumerate(rows[1:]):
-                row = [cell.value for ii,cell in enumerate(r) if ii not in missed]
+                row = [cell.value for ii, cell in enumerate(r) if ii not in missed]
                 if len(row) != len(headers):
                     fails.append(str(i))
                     continue
                 data.append(row)
                 done += 1
                 if done % 999 == 0:
-                    callback(done * 0.6/total, ("Extract records: {}".format(len(res)) + (f"{len(fails)} failure({sheetname}), line: %s..."%(",".join(fails[:3])) if fails else "")))
+                    callback(done * 0.6 / total, ("Extract records: {}".format(len(res)) + (
+                        f"{len(fails)} failure({sheetname}), line: %s..." % (",".join(fails[:3])) if fails else "")))
             res.append(pd.DataFrame(np.array(data), columns=headers))

         callback(0.6, ("Extract records: {}. ".format(done) + (
@@ -61,9 +62,10 @@ def trans_bool(s):
 def column_data_type(arr):
     uni = len(set([a for a in arr if a is not None]))
     counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
-    trans = {t:f for f,t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
+    trans = {t: f for f, t in
+             [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
     for a in arr:
-        if a is None:continue
+        if a is None: continue
         if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
             counts["int"] += 1
         elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
@@ -72,17 +74,18 @@ def column_data_type(arr):
             counts["bool"] += 1
         elif trans_datatime(str(a)):
             counts["datetime"] += 1
-        else: counts["text"] += 1
-    counts = sorted(counts.items(), key=lambda x: x[1]*-1)
+        else:
+            counts["text"] += 1
+    counts = sorted(counts.items(), key=lambda x: x[1] * -1)
     ty = counts[0][0]
     for i in range(len(arr)):
-        if arr[i] is None:continue
+        if arr[i] is None: continue
         try:
             arr[i] = trans[ty](str(arr[i]))
         except Exception as e:
             arr[i] = None
     if ty == "text":
-        if len(arr) > 128 and uni/len(arr) < 0.1:
+        if len(arr) > 128 and uni / len(arr) < 0.1:
             ty = "keyword"
     return arr, ty
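
For reference, roughly how the type inference above behaves on a couple of made-up columns:

arr, ty = column_data_type(["12", "7", None, "30"])
# counts -> {"int": 3, ...}, so ty == "int" and arr == [12, 7, None, 30]

arr, ty = column_data_type(["2023-01-02", "2024-05-06"])
# neither value looks numeric or boolean, but trans_datatime() parses both, so ty == "datetime"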
@@ -123,48 +126,51 @@ def chunk(filename, binary=None, callback=None, **kwargs):
         dfs = [pd.DataFrame(np.array(rows), columns=headers)]
-    else: raise NotImplementedError("file type not supported yet(excel, text, csv supported)")
+    else:
+        raise NotImplementedError("file type not supported yet(excel, text, csv supported)")

     res = []
     PY = Pinyin()
     fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
     for df in dfs:
         for n in ["id", "_id", "index", "idx"]:
-            if n in df.columns:del df[n]
+            if n in df.columns: del df[n]
         clmns = df.columns.values
         txts = list(copy.deepcopy(clmns))
         py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
         clmn_tys = []
         for j in range(len(clmns)):
-            cln,ty = column_data_type(df[clmns[j]])
+            cln, ty = column_data_type(df[clmns[j]])
            clmn_tys.append(ty)
            df[clmns[j]] = cln
            if ty == "text": txts.extend([str(c) for c in cln if c])
         clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for i in range(len(clmns))]
-        # TODO: set this column map to KB parser configuration
         eng = is_english(txts)
-        for ii,row in df.iterrows():
+        for ii, row in df.iterrows():
             d = {}
             row_txt = []
             for j in range(len(clmns)):
-                if row[clmns[j]] is None:continue
+                if row[clmns[j]] is None: continue
                 fld = clmns_map[j][0]
                 d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
                 row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
-            if not row_txt:continue
+            if not row_txt: continue
             tokenize(d, "; ".join(row_txt), eng)
+            print(d)
             res.append(d)
+        KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
     callback(0.6, "")

     return res

-if __name__== "__main__":
+if __name__ == "__main__":
     import sys

     def dummy(a, b):
         pass

     chunk(sys.argv[1], callback=dummy)

View File

@@ -74,7 +74,9 @@ class Dealer:
             s = s.highlight("title_ltks")
         if not qst:
             s = s.sort(
-                {"create_time": {"order": "desc", "unmapped_type": "date"}})
+                {"create_time": {"order": "desc", "unmapped_type": "date"}},
+                {"create_timestamp_flt": {"order": "desc", "unmapped_type": "float"}}
+            )

         if qst:
             s = s.highlight_options(
@@ -298,3 +300,22 @@ class Dealer:
                 ranks["doc_aggs"][dnm] += 1

         return ranks
+
+    def sql_retrieval(self, sql, fetch_size=128):
+        sql = re.sub(r"[ ]+", " ", sql)
+        replaces = []
+        for r in re.finditer(r" ([a-z_]+_l?tks like |[a-z_]+_l?tks ?= ?)'([^']+)'", sql):
+            fld, v = r.group(1), r.group(2)
+            fld = re.sub(r" ?(like|=)$", "", fld).lower()
+            if v[0] == "%%": v = v[1:-1]
+            match = " MATCH({}, '{}', 'operator=OR;fuzziness=AUTO:1,3;minimum_should_match=30%') ".format(fld, huqie.qie(v))
+            replaces.append((r.group(1) + r.group(2), match))
+
+        for p, r in replaces: sql.replace(p, r)
+
+        try:
+            tbl = self.es.sql(sql, fetch_size)
+            return tbl
+        except Exception as e:
+            es_logger(f"SQL failure: {sql} =>" + str(e))

rag/nlp/surname.py (new file, 127 lines)
View File

@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
m = set([
    # several hundred single-character surnames (the characters did not survive this rendering of the diff)
    "万俟", "司马", "上官", "欧阳",
    "夏侯", "诸葛", "闻人", "东方",
    "赫连", "皇甫", "尉迟", "公羊",
    "澹台", "公冶", "宗政", "濮阳",
    "淳于", "单于", "太叔", "申屠",
    "公孙", "仲孙", "轩辕", "令狐",
    "钟离", "宇文", "长孙", "慕容",
    "鲜于", "闾丘", "司徒", "司空",
    "亓官", "司寇", "仉督", "子车",
    "颛孙", "端木", "巫马", "公西",
    "漆雕", "乐正", "壤驷", "公良",
    "拓跋", "夹谷", "宰父", "榖梁",
    "段干", "百里", "东郭", "南门",
    "呼延", "羊舌",
    "梁丘", "左丘", "东门", "西门",
    "南宫",
    "第五"])
def isit(n):return n.strip() in m
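
Typical use of this lookup (the example inputs are arbitrary):

from rag.nlp import surname

surname.isit("司马")   # True: compound surname present in the set
surname.isit("abcd")   # False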

View File

@@ -81,11 +81,13 @@ def dispatch():
         tsks = []
         if r["type"] == FileType.PDF.value:
             pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
-            for p in range(0, pages, 10):
-                task = new_task()
-                task["from_page"] = p
-                task["to_page"] = min(p + 10, pages)
-                tsks.append(task)
+            for s,e in r["parser_config"].get("pages", [(0,100000)]):
+                e = min(e, pages)
+                for p in range(s, e, 10):
+                    task = new_task()
+                    task["from_page"] = p
+                    task["to_page"] = min(p + 10, e)
+                    tsks.append(task)
         else:
             tsks.append(new_task())

         print(tsks)
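
A sketch of how the new per-range splitting behaves, assuming a 25-page PDF and the new default parser_config from db_models.py (the page count is invented):

pages = 25
parser_config = {"pages": [[0, 1000000]]}
tsks = []
for s, e in parser_config.get("pages", [(0, 100000)]):
    e = min(e, pages)
    for p in range(s, e, 10):
        tsks.append({"from_page": p, "to_page": min(p + 10, e)})
# tsks == [{'from_page': 0, 'to_page': 10},
#          {'from_page': 10, 'to_page': 20},
#          {'from_page': 20, 'to_page': 25}]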

View File

@@ -58,7 +58,7 @@ FACTORY = {
 }

-def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
+def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     cancel = TaskService.do_cancel(task_id)
     if cancel:
         msg += " [Canceled]"
@@ -110,7 +110,7 @@ def collect(comm, mod, tm):
 def build(row, cvmdl):
     if row["size"] > DOC_MAXIMUM_SIZE:
-        set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
+        set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
                      (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         return []
@@ -119,7 +119,7 @@ def build(row, cvmdl):
     try:
         cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
         cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
-                            callback)
+                            callback, kb_id=row["kb_id"])
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s>" % row["doc_name"])
@@ -144,6 +144,7 @@ def build(row, cvmdl):
         md5.update((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
         if not d.get("image"):
             docs.append(d)
             continue
@@ -197,15 +198,15 @@ def main(comm, mod):
     tmf = open(tm_fnm, "a+")
     for _, r in rows.iterrows():
+        callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
         try:
             embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
             cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT)
             # TODO: sequence2text model
         except Exception as e:
-            set_progress(r["id"], -1, str(e))
+            callback(prog=-1, msg=str(e))
             continue
-        callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
         st_tm = timer()
         cks = build(r, cv_mdl)
         if not cks:

View File

@@ -3,13 +3,14 @@ import json
 import time
 import copy

 import elasticsearch
+from elastic_transport import ConnectionTimeout
 from elasticsearch import Elasticsearch
 from elasticsearch_dsl import UpdateByQuery, Search, Index
 from rag.settings import es_logger
 from rag import settings
 from rag.utils import singleton

-es_logger.info("Elasticsearch version: "+ str(elasticsearch.__version__))
+es_logger.info("Elasticsearch version: "+str(elasticsearch.__version__))

 @singleton
@@ -57,7 +58,7 @@ class HuEs:
                     body=d,
                     id=id,
                     doc_type="doc",
-                    refresh=False,
+                    refresh=True,
                     retry_on_conflict=100)
             else:
                 r = self.es.update(
@@ -65,7 +66,7 @@ class HuEs:
                     self.idxnm if not idxnm else idxnm),
                     body=d,
                     id=id,
-                    refresh=False,
+                    refresh=True,
                     retry_on_conflict=100)
             es_logger.info("Successfully upsert: %s" % id)
             T = True
@@ -240,6 +241,18 @@ class HuEs:
         es_logger.error("ES search timeout for 3 times!")
         raise Exception("ES search timeout.")

+    def sql(self, sql, fetch_size=128, format="json", timeout=2):
+        for i in range(3):
+            try:
+                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
+                return res
+            except ConnectionTimeout as e:
+                es_logger.error("Timeout【Q】" + sql)
+                continue
+        es_logger.error("ES search timeout for 3 times!")
+        raise ConnectionTimeout()
+
     def get(self, doc_id, idxnm=None):
         for i in range(3):
             try:
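
Example of calling the new SQL entry point through the shared client (the index name and query are placeholders); the response follows the Elasticsearch SQL JSON format, i.e. a dict with "columns" and "rows":

from rag.utils import ELASTICSEARCH

tbl = ELASTICSEARCH.sql("select doc_id, docnm_kwd from ragflow_tenant limit 10", fetch_size=10)
for row in tbl["rows"]:
    print(row)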
@@ -308,7 +321,8 @@ class HuEs:
         try:
             r = self.es.delete_by_query(
                 index=idxnm if idxnm else self.idxnm,
-                body=Search().query(query).to_dict())
+                refresh = True,
+                body=Search().query(query).to_dict())
             return True
         except Exception as e:
             es_logger.error("ES updateByQuery deleteByQuery: " +