ragflow/python/nlp/search.py

224 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from elasticsearch_dsl import Q,Search,A
from typing import List, Optional, Tuple,Dict, Union
from dataclasses import dataclass
from util import setup_logging, rmSpace
from nlp import huqie, query
from datetime import datetime
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
from copy import deepcopy
def index_name(uid):return f"docgpt_{uid}"
class Dealer:
def __init__(self, es, emb_mdl):
self.qryr = query.EsQueryer(es)
self.qryr.flds = ["title_tks^10", "title_sm_tks^5", "content_ltks^2", "content_sm_ltks"]
self.es = es
self.emb_mdl = emb_mdl
@dataclass
class SearchResult:
total:int
ids: List[str]
query_vector: List[float] = None
field: Optional[Dict] = None
highlight: Optional[Dict] = None
aggregation: Union[List, Dict, None] = None
keywords: Optional[List[str]] = None
group_docs: List[List] = None
def _vector(self, txt, sim=0.8, topk=10):
return {
"field": "q_vec",
"k": topk,
"similarity": sim,
"num_candidates": 1000,
"query_vector": self.emb_mdl.encode_queries(txt)
}
def search(self, req, idxnm, tks_num=3):
keywords = []
qst = req.get("question", "")
bqry,keywords = self.qryr.question(qst)
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
bqry.filter.append(Q("exists", field="q_tks"))
bqry.boost = 0.05
print(bqry)
s = Search()
pg = int(req.get("page", 1))-1
ps = int(req.get("size", 1000))
src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
"image_id", "doc_id", "q_vec"])
s = s.query(bqry)[pg*ps:(pg+1)*ps]
s = s.highlight("content_ltks")
s = s.highlight("title_ltks")
if not qst: s = s.sort({"create_time":{"order":"desc", "unmapped_type":"date"}})
s = s.highlight_options(
fragment_size = 120,
number_of_fragments=5,
boundary_scanner_locale="zh-CN",
boundary_scanner="SENTENCE",
boundary_chars=",./;:\\!(),。?:!……()——、"
)
s = s.to_dict()
q_vec = []
if req.get("vector"):
s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
s["knn"]["filter"] = bqry.to_dict()
del s["highlight"]
q_vec = s["knn"]["query_vector"]
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
print("TOTAL: ", self.es.getTotal(res))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry,_ = self.qryr.question(qst, min_match="10%")
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.7
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
kwds = set([])
for k in keywords:
kwds.add(k)
for kk in huqie.qieqie(k).split(" "):
if len(kk) < 2:continue
if kk in kwds:continue
kwds.add(kk)
aggs = self.getAggregation(res, "docnm_kwd")
return self.SearchResult(
total = self.es.getTotal(res),
ids = self.es.getDocIds(res),
query_vector = q_vec,
aggregation = aggs,
highlight = self.getHighlight(res),
field = self.getFields(res, ["docnm_kwd", "content_ltks",
"kb_id","image_id", "doc_id", "q_vec"]),
keywords = list(kwds)
)
def getAggregation(self, res, g):
if not "aggregations" in res or "aggs_"+g not in res["aggregations"]:return
bkts = res["aggregations"]["aggs_"+g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
def getHighlight(self, res):
def rmspace(line):
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
r = []
for t in line.split(" "):
if not t:continue
if len(r)>0 and len(t)>0 and r[-1][-1] in eng and t[0] in eng:r.append(" ")
r.append(t)
r = "".join(r)
return r
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:continue
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
return ans
def getFields(self, sres, flds):
res = {}
if not flds:return {}
for d in self.es.getSource(sres):
m = {n:d.get(n) for n in flds if d.get(n) is not None}
for n,v in m.items():
if type(v) == type([]):
m[n] = "\t".join([str(vv) for vv in v])
continue
if type(v) != type(""):m[n] = str(m[n])
m[n] = rmSpace(m[n])
if m:res[d["id"]] = m
return res
@staticmethod
def trans2floats(txt):
return [float(t) for t in txt.split("\t")]
def insert_citations(self, ans, top_idx, sres, vfield = "q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats(sres.field[sres.ids[i]][vfield]) for i in top_idx]
ins_tw =[sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
s = 0
e = 0
res = ""
def citeit():
nonlocal s, e, ans, res
if not ins_embd:return
embd = self.emb_mdl.encode(ans[s: e])
sim = self.qryr.hybrid_similarity(embd,
ins_embd,
huqie.qie(ans[s:e]).split(" "),
ins_tw)
print(ans[s: e], sim)
mx = np.max(sim)*0.99
if mx < 0.55:return
cita = list(set([top_idx[i] for i in range(len(ins_embd)) if sim[i] >mx]))[:4]
for i in cita: res += f"@?{i}?@"
return cita
punct = set(";。?!")
if not self.qryr.isChinese(ans):
punct.add("?")
punct.add(".")
while e < len(ans):
if e - s < 12 or ans[e] not in punct:
e += 1
continue
if ans[e] == "." and e+1<len(ans) and re.match(r"[0-9]", ans[e+1]):
e += 1
continue
if ans[e] == "." and e-2>=0 and ans[e-2] == "\n":
e += 1
continue
res += ans[s: e]
citeit()
res += ans[e]
e += 1
s = e
if s< len(ans):
res += ans[s:]
citeit()
return res
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, vfield="q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats(sres.field[i]["q_vec"]) for i in sres.ids]
if not ins_embd: return []
ins_tw =[sres.field[i][cfield].split(" ") for i in sres.ids]
#return CosineSimilarity([sres.query_vector], ins_embd)[0]
sim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
huqie.qie(query).split(" "),
ins_tw, tkweight, vtweight)
return sim
if __name__ == "__main__":
from util import es_conn
SE = Dealer(es_conn.HuEs("infiniflow"))
qs = [
"胡凯",
""
]
for q in qs:
print(">>>>>>>>>>>>>>>>>>>>", q)
print(SE.search({"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))