From 9d4bb5767c0ebf544c1517ff2b6b5bc65443f417 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Fri, 13 Sep 2024 17:03:51 +0800 Subject: [PATCH] make highlight friendly to English (#2417) ### What problem does this PR solve? #2415 ### Type of change - [x] Performance Improvement --- rag/nlp/__init__.py | 2 +- rag/nlp/search.py | 36 +++++++++++++++++++----------------- rag/utils/__init__.py | 4 ++-- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 64e953cf0..9d8e78763 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -214,7 +214,7 @@ def is_english(texts): eng = 0 if not texts: return False for t in texts: - if re.match(r"[a-zA-Z]{2,}", t.strip()): + if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()): eng += 1 if eng / len(texts) > 0.8: return True diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 478a0909b..2d5ba5945 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -24,7 +24,7 @@ from dataclasses import dataclass from rag.settings import es_logger from rag.utils import rmSpace -from rag.nlp import rag_tokenizer, query +from rag.nlp import rag_tokenizer, query, is_english import numpy as np @@ -164,7 +164,7 @@ class Dealer: ids=self.es.getDocIds(res), query_vector=q_vec, aggregation=aggs, - highlight=self.getHighlight(res), + highlight=self.getHighlight(res, keywords, "content_with_weight"), field=self.getFields(res, src), keywords=list(kwds) ) @@ -175,26 +175,28 @@ class Dealer: bkts = res["aggregations"]["aggs_" + g]["buckets"] return [(b["key"], b["doc_count"]) for b in bkts] - def getHighlight(self, res): - def rmspace(line): - eng = set(list("qwertyuioplkjhgfdsazxcvbnm")) - r = [] - for t in line.split(" "): - if not t: - continue - if len(r) > 0 and len( - t) > 0 and r[-1][-1] in eng and t[0] in eng: - r.append(" ") - r.append(t) - r = "".join(r) - return r - + def getHighlight(self, res, keywords, fieldnm): ans = {} for d in res["hits"]["hits"]: hlts = d.get("highlight") if not hlts: continue - ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]]) + txt = "...".join([a for a in list(hlts.items())[0][1]]) + if not is_english(txt.split(" ")): + ans[d["_id"]] = txt + continue + + txt = d["_source"][fieldnm] + txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE) + txts = [] + for w in keywords: + txt = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1\2\3", txt, flags=re.IGNORECASE|re.MULTILINE) + + for t in re.split(r"[.?!;\n]", txt): + if not re.search(r"[^<>]+", t, flags=re.IGNORECASE|re.MULTILINE): continue + txts.append(t) + ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]]) + return ans def getFields(self, sres, flds): diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index 8f0d8c28f..796b0e965 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -32,8 +32,8 @@ def singleton(cls, *args, **kw): def rmSpace(txt): - txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE) - return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt, flags=re.IGNORECASE) + txt = re.sub(r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE) + return re.sub(r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", txt, flags=re.IGNORECASE) def findMaxDt(fnm):