refactor retrieval_test, add SQL retrieval methods (#61)

This commit is contained in:
KevinHuSh 2024-02-08 17:01:01 +08:00 committed by GitHub
parent 0a903c7714
commit 5e0a689c43
16 changed files with 238 additions and 74 deletions

View File

@ -227,7 +227,7 @@ def retrieval_test():
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top", 1024))
top = int(req.get("top_k", 1024))
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
@ -237,6 +237,9 @@ def retrieval_test():
kb.tenant_id, LLMType.EMBEDDING.value)
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
vector_similarity_weight, top, doc_ids)
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
return get_json_result(data=ranks)
except Exception as e:
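
The hunk above renames the request field "top" to "top_k" and deletes the raw embedding vector from every chunk before the response is serialized. A minimal client-side sketch against this endpoint; the host, route prefix and kb_id are placeholders, not taken from this diff:

import requests

url = "http://localhost:9380/v1/chunk/retrieval_test"   # assumed route; adjust to your deployment
payload = {
    "kb_id": "<knowledge-base-id>",
    "question": "What is the refund policy?",
    "page": 1,
    "size": 30,
    "similarity_threshold": 0.2,
    "vector_similarity_weight": 0.3,
    "top_k": 1024,                      # renamed from "top" in this commit
    "doc_ids": [],
}
resp = requests.post(url, json=payload).json()
# "vector" has been stripped from each chunk, so the payload stays compact.
for chunk in resp["data"]["chunks"]:
    print(chunk["chunk_id"])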

View File

@ -229,6 +229,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
sql = re.sub(r" +", " ", sql)
sql = re.sub(r"[;].*", "", sql)
if sql[:len("select ")].lower() != "select ":
return None, None
if sql[:len("select *")].lower() != "select *":
@ -241,6 +242,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
# compose markdown table
clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]
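
use_sql() asks the chat model for a SQL statement and then defensively normalizes whatever text comes back before executing it. The same clean-up chain in isolation, applied to a hypothetical single-line model reply (column and table names are illustrative):

import re

raw = "Here is the query you asked for: SELECT docnm_kwd,  age_int FROM resume;  -- explanation"

sql = re.sub(r".*?select ", "select ", raw, flags=re.IGNORECASE)  # drop any preamble before the SELECT
sql = re.sub(r" +", " ", sql)                                     # collapse repeated spaces
sql = re.sub(r"[;].*", "", sql)                                   # new in this commit: keep only the first statement
print(sql)  # select docnm_kwd, age_int FROM resume

if sql[:len("select ")].lower() != "select ":
    sql = None  # anything that is not a plain SELECT is rejected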

View File

@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
#
import base64
import pathlib
import re
import flask
from elasticsearch_dsl import Q
@ -27,7 +28,7 @@ from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.db import FileType, TaskStatus
from api.db import FileType, TaskStatus, ParserType
from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api.utils.api_utils import get_json_result
@ -66,7 +67,7 @@ def upload():
location += "_"
blob = request.files['file'].read()
MINIO.put(kb_id, location, blob)
doc = DocumentService.insert({
doc = {
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
@ -77,7 +78,12 @@ def upload():
"location": location,
"size": len(blob),
"thumbnail": thumbnail(filename, blob)
})
}
if doc["type"] == FileType.VISUAL:
doc["parser_id"] = ParserType.PICTURE.value
if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value
doc = DocumentService.insert(doc)
return get_json_result(data=doc.to_json())
except Exception as e:
return server_error_response(e)
@ -283,6 +289,9 @@ def change_parser():
if doc.parser_id.lower() == req["parser_id"].lower():
return get_json_result(data=True)
if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
return get_data_error_result(retmsg="Not supported yet!")
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
if not e:
return get_data_error_result(retmsg="Document not found!")

View File

@ -78,3 +78,5 @@ class ParserType(StrEnum):
BOOK = "book"
QA = "qa"
TABLE = "table"
NAIVE = "naive"
PICTURE = "picture"

View File

@ -381,7 +381,7 @@ class Tenant(DataBaseModel):
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="document processors")
parser_ids = CharField(max_length=256, null=False, help_text="document processors")
credit = IntegerField(default=512)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
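
parser_ids grows from 128 to 256 characters so the tenant row can hold the longer parser list introduced in settings.py. On an already-provisioned MySQL database this implies an ALTER of the tenant table; a hedged sketch with peewee's playhouse migrator (connection details, table and column names are assumed to match the model above):

from peewee import CharField, MySQLDatabase
from playhouse.migrate import MySQLMigrator, migrate

db = MySQLDatabase("rag_flow", user="root", password="<password>", host="127.0.0.1")  # placeholder DSN
migrator = MySQLMigrator(db)

# Widen tenant.parser_ids from VARCHAR(128) to VARCHAR(256).
migrate(
    migrator.alter_column_type("tenant", "parser_ids", CharField(max_length=256, null=False)),
)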

View File

@ -63,7 +63,9 @@ def init_llm_factory():
"status": "1",
},
]
llm_infos = [{
llm_infos = [
# ---------------------- OpenAI ------------------------
{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
@ -105,7 +107,9 @@ def init_llm_factory():
"tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value
},{
},
# ----------------------- Qwen -----------------------
{
"fid": factory_infos[1]["name"],
"llm_name": "qwen-turbo",
"tags": "LLM,CHAT,8K",
@ -135,7 +139,9 @@ def init_llm_factory():
"tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value
},{
},
# ----------------------- Infiniflow -----------------------
{
"fid": factory_infos[2]["name"],
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
@ -160,6 +166,33 @@ def init_llm_factory():
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value
},
# ---------------------- ZhipuAI ----------------------
{
"fid": factory_infos[3]["name"],
"llm_name": "glm-3-turbo",
"tags": "LLM,CHAT,",
"max_tokens": 128 * 1000,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[3]["name"],
"llm_name": "glm-4",
"tags": "LLM,CHAT,",
"max_tokens": 128 * 1000,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[3]["name"],
"llm_name": "glm-4v",
"tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 2000,
"model_type": LLMType.IMAGE2TEXT.value
},
{
"fid": factory_infos[3]["name"],
"llm_name": "embedding-2",
"tags": "TEXT EMBEDDING",
"max_tokens": 512,
"model_type": LLMType.SPEECH2TEXT.value
},
]
for info in factory_infos:
LLMFactoriesService.save(**info)

View File

@ -47,7 +47,7 @@ LLM = get_base_config("llm", {})
CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo")
EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002")
ASR_MDL = LLM.get("asr_model", "whisper-1")
PARSERS = LLM.get("parsers", "general:General,resume:esume,laws:Laws,manual:Manual,book:Book,paper:Paper,qa:Q&A,presentation:Presentation")
PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
# distribution
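
The default PARSERS string now lists every implemented chunker, including the new naive, table and picture entries, as comma-separated "id:Label" pairs. A small sketch of turning that flat string into a mapping (variable names are illustrative):

PARSERS = ("general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,"
           "manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")

parser_labels = dict(item.split(":", 1) for item in PARSERS.split(","))
print(parser_labels["picture"])   # Picture
print(list(parser_labels))        # ['general', 'qa', 'resume', ..., 'picture']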

View File

@ -57,7 +57,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
cks = naive_merge(sections, kwargs.get("chunk_token_num", 128), kwargs.get("delimer", "\n。;!?"))
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimer": "\n。;!?"})
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimer"])
eng = is_english(cks)
res = []
# wrap up to es documents
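
naive.chunk() now reads its chunking parameters from a per-knowledge-base parser_config dict instead of loose kwargs, with the same fallback defaults (128 tokens per chunk, the "\n。;!?" delimiters). A sketch of calling it with a custom configuration; the file name and callback are placeholders, and the module path follows the other chunkers under rag.app:

from rag.app import naive

def progress(prog, msg=""):
    print(f"{prog:.0%} {msg}")

chunks = naive.chunk(
    "handbook.pdf",                                              # placeholder document
    callback=progress,
    parser_config={"chunk_token_num": 256, "delimer": "\n。;!?"},
)
print(len(chunks), "chunks")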

View File

@ -24,31 +24,45 @@ class Excel(object):
for i, r in enumerate(rows):
q, a = "", ""
for cell in r:
if not cell.value: continue
if not q: q = str(cell.value)
elif not a: a = str(cell.value)
else: break
if q and a: res.append((q, a))
else: fails.append(str(i+1))
if not cell.value:
continue
if not q:
q = str(cell.value)
elif not a:
a = str(cell.value)
else:
break
if q and a:
res.append((q, a))
else:
fails.append(str(i + 1))
if len(res) % 999 == 0:
callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else "")))
callback(len(res) *
0.6 /
total, ("Extract Q&A: {}".format(len(res)) +
(f"{len(fails)} failure, line: %s..." %
(",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q)>1])
self.is_english = is_english(
[rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
return res
def rmPrefix(txt):
return re.sub(r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t: ]+", "", txt.strip(), flags=re.IGNORECASE)
return re.sub(
r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t: ]+", "", txt.strip(), flags=re.IGNORECASE)
def beAdoc(d, q, a, eng):
qprefix = "Question: " if eng else "问题:"
aprefix = "Answer: " if eng else "回答:"
d["content_with_weight"] = "\t".join([qprefix+rmPrefix(q), aprefix+rmPrefix(a)])
d["content_with_weight"] = "\t".join(
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
if eng:
d["content_ltks"] = " ".join([stemmer.stem(w) for w in word_tokenize(q)])
d["content_ltks"] = " ".join([stemmer.stem(w)
for w in word_tokenize(q)])
else:
d["content_ltks"] = huqie.qie(q)
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
@ -61,7 +75,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
for q,a in excel_parser(filename, binary, callback):
for q, a in excel_parser(filename, binary, callback):
res.append(beAdoc({}, q, a, excel_parser.is_english))
return res
elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
@ -73,7 +87,8 @@ def chunk(filename, binary=None, callback=None, **kwargs):
with open(filename, "r") as f:
while True:
l = f.readline()
if not l: break
if not l:
break
txt += l
lines = txt.split("\n")
eng = is_english([rmPrefix(l) for l in lines[:100]])
@ -93,12 +108,13 @@ def chunk(filename, binary=None, callback=None, **kwargs):
return res
raise NotImplementedError("file type not supported yet(pptx, pdf supported)")
raise NotImplementedError(
"file type not supported yet(pptx, pdf supported)")
if __name__== "__main__":
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)
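
The Q&A chunker expects two usable cells per spreadsheet row, question first and answer second; each pair is prefixed ("Question:"/"Answer:" or "问题:"/"回答:" depending on the detected language), joined into content_with_weight, and the question side is tokenized for retrieval. A minimal sketch of running it on an Excel workbook (the file name is a placeholder):

from rag.app import qa

def progress(prog, msg=""):
    print(f"{prog:.0%} {msg}")

docs = qa.chunk("faq.xlsx", callback=progress)   # two-column sheet: question, answer
print(len(docs), "Q&A chunks")
print(docs[0]["content_with_weight"])            # "Question: ...\tAnswer: ..."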

View File

@ -11,15 +11,22 @@ from rag.utils import rmSpace
def chunk(filename, binary=None, callback=None, **kwargs):
if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE): raise NotImplementedError("file type not supported yet(pdf supported)")
if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
raise NotImplementedError("file type not supported yet(pdf supported)")
url = os.environ.get("INFINIFLOW_SERVER")
if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
if not url:
raise EnvironmentError(
"Please set environment variable: 'INFINIFLOW_SERVER'")
token = os.environ.get("INFINIFLOW_TOKEN")
if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")
if not token:
raise EnvironmentError(
"Please set environment variable: 'INFINIFLOW_TOKEN'")
if not binary:
with open(filename, "rb") as f: binary = f.read()
with open(filename, "rb") as f:
binary = f.read()
def remote_call():
nonlocal filename, binary
for _ in range(3):
@ -27,14 +34,17 @@ def chunk(filename, binary=None, callback=None, **kwargs):
res = requests.post(url + "/v1/layout/resume/", files=[(filename, binary)],
headers={"Authorization": token}, timeout=180)
res = res.json()
if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
if res["retcode"] != 0:
raise RuntimeError(res["retmsg"])
return res["data"]
except RuntimeError as e:
raise e
except Exception as e:
cron_logger.error("resume parsing:" + str(e))
callback(0.2, "Resume parsing is going on...")
resume = remote_call()
callback(0.6, "Done parsing. Chunking...")
print(json.dumps(resume, ensure_ascii=False, indent=2))
field_map = {
@ -69,34 +79,43 @@ def chunk(filename, binary=None, callback=None, **kwargs):
titles = []
for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
v = resume.get(n, "")
if isinstance(v, list):v = v[0]
if n.find("tks") > 0: v = rmSpace(v)
if isinstance(v, list):
v = v[0]
if n.find("tks") > 0:
v = rmSpace(v)
titles.append(str(v))
doc = {
"docnm_kwd": filename,
"title_tks": huqie.qie("-".join(titles)+"-简历")
"title_tks": huqie.qie("-".join(titles) + "-简历")
}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
pairs = []
for n,m in field_map.items():
if not resume.get(n):continue
for n, m in field_map.items():
if not resume.get(n):
continue
v = resume[n]
if isinstance(v, list):v = " ".join(v)
if n.find("tks") > 0: v = rmSpace(v)
if isinstance(v, list):
v = " ".join(v)
if n.find("tks") > 0:
v = rmSpace(v)
pairs.append((m, str(v)))
doc["content_with_weight"] = "\n".join(["{}: {}".format(re.sub(r"[^]+", "", k), v) for k,v in pairs])
doc["content_with_weight"] = "\n".join(
["{}: {}".format(re.sub(r"[^]+", "", k), v) for k, v in pairs])
doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
for n, _ in field_map.items(): doc[n] = resume[n]
for n, _ in field_map.items():
doc[n] = resume[n]
print(doc)
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map})
KnowledgebaseService.update_parser_config(
kwargs["kb_id"], {"field_map": field_map})
return [doc]
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)
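
resume.chunk() delegates parsing to a remote service, so both environment variables have to be set before the task executor starts; the request is retried up to three times and the recognised fields are written back into the knowledge base's parser_config as a field_map. A minimal sketch of exercising it directly (server URL, token, file and kb_id are placeholders, and a reachable database is assumed for the field_map update):

import os
from rag.app import resume

os.environ["INFINIFLOW_SERVER"] = "http://127.0.0.1:8000"   # placeholder endpoint
os.environ["INFINIFLOW_TOKEN"] = "<token>"                  # placeholder credential

def progress(prog, msg=""):
    print(f"{prog:.0%} {msg}")

docs = resume.chunk("candidate.pdf", callback=progress, kb_id="<kb-id>")
print(docs[0]["title_tks"])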

View File

@ -28,10 +28,15 @@ class Excel(object):
rows = list(ws.rows)
headers = [cell.value for cell in rows[0]]
missed = set([i for i, h in enumerate(headers) if h is None])
headers = [cell.value for i, cell in enumerate(rows[0]) if i not in missed]
headers = [
cell.value for i,
cell in enumerate(
rows[0]) if i not in missed]
data = []
for i, r in enumerate(rows[1:]):
row = [cell.value for ii, cell in enumerate(r) if ii not in missed]
row = [
cell.value for ii,
cell in enumerate(r) if ii not in missed]
if len(row) != len(headers):
fails.append(str(i))
continue
@ -55,8 +60,10 @@ def trans_datatime(s):
def trans_bool(s):
if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", ""]
if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", ""]
if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE):
return ["yes", ""]
if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE):
return ["no", ""]
def column_data_type(arr):
@ -65,7 +72,8 @@ def column_data_type(arr):
trans = {t: f for f, t in
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
for a in arr:
if a is None: continue
if a is None:
continue
if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
counts["int"] += 1
elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
@ -79,7 +87,8 @@ def column_data_type(arr):
counts = sorted(counts.items(), key=lambda x: x[1] * -1)
ty = counts[0][0]
for i in range(len(arr)):
if arr[i] is None: continue
if arr[i] is None:
continue
try:
arr[i] = trans[ty](str(arr[i]))
except Exception as e:
@ -105,7 +114,8 @@ def chunk(filename, binary=None, callback=None, **kwargs):
with open(filename, "r") as f:
while True:
l = f.readline()
if not l: break
if not l:
break
txt += l
lines = txt.split("\n")
fails = []
@ -127,14 +137,22 @@ def chunk(filename, binary=None, callback=None, **kwargs):
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
else:
raise NotImplementedError("file type not supported yet(excel, text, csv supported)")
raise NotImplementedError(
"file type not supported yet(excel, text, csv supported)")
res = []
PY = Pinyin()
fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
fieds_map = {
"text": "_tks",
"int": "_int",
"keyword": "_kwd",
"float": "_flt",
"datetime": "_dt",
"bool": "_kwd"}
for df in dfs:
for n in ["id", "_id", "index", "idx"]:
if n in df.columns: del df[n]
if n in df.columns:
del df[n]
clmns = df.columns.values
txts = list(copy.deepcopy(clmns))
py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
@ -143,23 +161,29 @@ def chunk(filename, binary=None, callback=None, **kwargs):
cln, ty = column_data_type(df[clmns[j]])
clmn_tys.append(ty)
df[clmns[j]] = cln
if ty == "text": txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for i in range(len(clmns))]
if ty == "text":
txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j])
for i in range(len(clmns))]
eng = is_english(txts)
for ii, row in df.iterrows():
d = {}
row_txt = []
for j in range(len(clmns)):
if row[clmns[j]] is None: continue
if row[clmns[j]] is None:
continue
fld = clmns_map[j][0]
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(
row[clmns[j]])
row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
if not row_txt: continue
if not row_txt:
continue
tokenize(d, "; ".join(row_txt), eng)
res.append(d)
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
KnowledgebaseService.update_parser_config(
kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
callback(0.6, "")
return res
@ -168,9 +192,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)
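
column_data_type() votes on a type for every column (int, float, datetime, bool or text), coerces the values, and each column is then stored under a typed field name via fieds_map ("_int", "_flt", "_dt", "_kwd", "_tks"). A self-contained sketch of the same majority-vote idea, simplified to three types rather than the exact function above:

import re

def guess_type(values):
    """Count which pattern each non-empty value matches and take the majority."""
    counts = {"int": 0, "float": 0, "text": 0}
    for v in values:
        if v is None:
            continue
        s = str(v)
        if re.match(r"[+-]?[0-9]+(\.0+)?$", s):
            counts["int"] += 1
        elif re.match(r"[+-]?[0-9.]+$", s):
            counts["float"] += 1
        else:
            counts["text"] += 1
    return max(counts, key=counts.get)

print(guess_type(["12", "7", None, "30"]))     # int
print(guess_type(["3.14", "2.5", "10"]))       # float
print(guess_type(["Beijing", "Shanghai"]))     # text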

View File

@ -58,3 +58,21 @@ class QWenChat(Base):
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0
from zhipuai import ZhipuAI
class ZhipuChat(Base):
def __init__(self, key, model_name="glm-3-turbo"):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
def chat(self, system, history, gen_conf):
from http import HTTPStatus
history.insert(0, {"role": "system", "content": system})
response = self.client.chat.completions.create(
self.model_name,
messages=history
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
return response.message, 0
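
ZhipuChat wires the GLM chat models into the LLM factory. For reference, a minimal standalone call with the zhipuai 2.x SDK, which exposes an OpenAI-style client; the API key and prompt are placeholders, and the exact response attributes should be checked against the SDK version you install:

from zhipuai import ZhipuAI

client = ZhipuAI(api_key="<your-api-key>")           # placeholder key
response = client.chat.completions.create(
    model="glm-3-turbo",
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Summarize RAG in one sentence."},
    ],
    temperature=0.1,
)
print(response.choices[0].message.content)
print(response.usage.completion_tokens)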

View File

@ -61,7 +61,7 @@ class Base(ABC):
class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview"):
self.client = OpenAI(api_key = key)
self.client = OpenAI(api_key=key)
self.model_name = model_name
def describe(self, image, max_tokens=300):
@ -89,3 +89,22 @@ class QWenCV(Base):
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0
from zhipuai import ZhipuAI
class Zhipu4V(Base):
def __init__(self, key, model_name="glm-4v"):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
def describe(self, image, max_tokens=1024):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.prompt(b64),
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens

View File

@ -19,7 +19,6 @@ import dashscope
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np
from rag.utils import num_tokens_from_string
@ -115,3 +114,20 @@ class QWenEmbed(Base):
text_type="query"
)
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
from zhipuai import ZhipuAI
class ZhipuEmbed(Base):
def __init__(self, key, model_name="embedding-2"):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
def encode(self, texts: list, batch_size=32):
res = self.client.embeddings.create(input=texts,
model=self.model_name)
return np.array([d.embedding for d in res.data]), res.usage.total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=text,
model=self.model_name)
return np.array(res["data"][0]["embedding"]), res.usage.total_tokens

View File

@ -268,9 +268,9 @@ class Dealer:
dim = len(sres.query_vector)
start_idx = (page - 1) * page_size
for i in idx:
ranks["total"] += 1
if sim[i] < similarity_threshold:
break
ranks["total"] += 1
start_idx -= 1
if start_idx >= 0:
continue
@ -280,6 +280,7 @@ class Dealer:
break
id = sres.ids[i]
dnm = sres.field[id]["docnm_kwd"]
did = sres.field[id]["doc_id"]
d = {
"chunk_id": id,
"content_ltks": sres.field[id]["content_ltks"],
@ -296,8 +297,9 @@ class Dealer:
}
ranks["chunks"].append(d)
if dnm not in ranks["doc_aggs"]:
ranks["doc_aggs"][dnm] = 0
ranks["doc_aggs"][dnm] += 1
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
ranks["doc_aggs"][dnm]["count"] += 1
ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)]
return ranks
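
doc_aggs changes from a plain name-to-count dict to a list of objects that also carries each doc_id, sorted by hit count, so callers can link the aggregated documents. The transformation in isolation (sample data is illustrative):

doc_aggs = {
    "handbook.pdf": {"doc_id": "d1", "count": 3},
    "faq.xlsx": {"doc_id": "d2", "count": 7},
}
doc_aggs = [
    {"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]}
    for k, v in sorted(doc_aggs.items(), key=lambda x: x[1]["count"] * -1)
]
print(doc_aggs)
# [{'doc_name': 'faq.xlsx', 'doc_id': 'd2', 'count': 7},
#  {'doc_name': 'handbook.pdf', 'doc_id': 'd1', 'count': 3}]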

View File

@ -36,7 +36,7 @@ from rag.nlp import search
from io import BytesIO
import pandas as pd
from rag.app import laws, paper, presentation, manual, qa, table,book
from rag.app import laws, paper, presentation, manual, qa, table, book, resume
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
@ -55,6 +55,7 @@ FACTORY = {
ParserType.LAWS.value: laws,
ParserType.QA.value: qa,
ParserType.TABLE.value: table,
ParserType.RESUME.value: resume,
}
@ -119,7 +120,7 @@ def build(row, cvmdl):
try:
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
callback, kb_id=row["kb_id"])
callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
except Exception as e:
if re.search("(No such file|not found)", str(e)):
callback(-1, "Can not find file <%s>" % row["doc_name"])
@ -171,7 +172,7 @@ def init_kb(row):
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
def embedding(docs, mdl):
def embedding(docs, mdl, parser_config={}):
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs]
tk_count = 0
if len(tts) == len(cnts):
@ -180,7 +181,8 @@ def embedding(docs, mdl):
cnts, c = mdl.encode(cnts)
tk_count += c
vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts
title_w = float(parser_config.get("filename_embd_weight", 0.1))
vects = (title_w * tts + (1-title_w) * cnts) if len(tts) == len(cnts) else cnts
assert len(vects) == len(docs)
for i, d in enumerate(docs):
@ -216,7 +218,7 @@ def main(comm, mod):
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
try:
tk_count = embedding(cks, embd_mdl)
tk_count = embedding(cks, embd_mdl, r["parser_config"])
except Exception as e:
callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e))
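
The title/content mix is no longer hard-coded at 0.1/0.9: filename_embd_weight from the knowledge base's parser_config now controls how much the file-name embedding contributes to each chunk vector, and the task executor passes that config through to embedding(). The weighted combination in isolation, with toy numpy vectors:

import numpy as np

parser_config = {"filename_embd_weight": 0.3}   # per-KB setting; falls back to 0.1
title_w = float(parser_config.get("filename_embd_weight", 0.1))

tts = np.array([[1.0, 0.0], [1.0, 0.0]])        # title (file name) embeddings
cnts = np.array([[0.0, 1.0], [0.5, 0.5]])       # chunk content embeddings

vects = (title_w * tts + (1 - title_w) * cnts) if len(tts) == len(cnts) else cnts
print(vects)
# [[0.3  0.7 ]
#  [0.65 0.35]]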