Add kb_id filter to kNN search. Fixes #3458 (#3513)

### What problem does this PR solve?

Added the kb_id filter to the kNN search query. Fixes #3458.

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Zhichang Yu 2024-11-20 11:47:39 +08:00 committed by Yingfeng Zhang
parent e559cebcdc
commit cad341e794
3 changed files with 39 additions and 43 deletions

View File

@ -757,7 +757,7 @@ class RAGFlowPdfParser:
if ii is not None: if ii is not None:
b = louts[ii] b = louts[ii]
else: else:
logging.warn( logging.warning(
f"Missing layout match: {pn + 1},%s" % f"Missing layout match: {pn + 1},%s" %
(bxs[0].get( (bxs[0].get(
"layoutno", ""))) "layoutno", "")))

View File

@ -33,7 +33,7 @@ class Dealer:
try: try:
self.dictionary = json.load(open(path, 'r')) self.dictionary = json.load(open(path, 'r'))
except Exception: except Exception:
logging.warn("Missing synonym.json") logging.warning("Missing synonym.json")
self.dictionary = {} self.dictionary = {}
if not redis: if not redis:

View File

@ -35,7 +35,7 @@ class ESConnection(DocStoreConnection):
self.info = self.es.info() self.info = self.es.info()
break break
except Exception as e: except Exception as e:
logging.warn(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.") logging.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
time.sleep(5) time.sleep(5)
if not self.es.ping(): if not self.es.ping():
msg = f"Elasticsearch {settings.ES['hosts']} didn't become healthy in 120s." msg = f"Elasticsearch {settings.ES['hosts']} didn't become healthy in 120s."
@ -80,7 +80,7 @@ class ESConnection(DocStoreConnection):
settings=self.mapping["settings"], settings=self.mapping["settings"],
mappings=self.mapping["mappings"]) mappings=self.mapping["mappings"])
except Exception: except Exception:
logging.exception("ES create index error %s" % (indexName)) logging.exception("ESConnection.createIndex error %s" % (indexName))
def deleteIdx(self, indexName: str, knowledgebaseId: str): def deleteIdx(self, indexName: str, knowledgebaseId: str):
try: try:
@ -88,7 +88,7 @@ class ESConnection(DocStoreConnection):
except NotFoundError: except NotFoundError:
pass pass
except Exception: except Exception:
logging.exception("ES delete index error %s" % (indexName)) logging.exception("ESConnection.deleteIdx error %s" % (indexName))
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool: def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
s = Index(indexName, self.es) s = Index(indexName, self.es)
@ -96,7 +96,7 @@ class ESConnection(DocStoreConnection):
try: try:
return s.exists() return s.exists()
except Exception as e: except Exception as e:
logging.exception("ES indexExist") logging.exception("ESConnection.indexExist got exception")
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue
return False return False
@ -115,8 +115,21 @@ class ESConnection(DocStoreConnection):
indexNames = indexNames.split(",") indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0 assert isinstance(indexNames, list) and len(indexNames) > 0
assert "_id" not in condition assert "_id" not in condition
bqry = Q("bool", must=[])
condition["kb_id"] = knowledgebaseIds
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search() s = Search()
bqry = None
vector_similarity_weight = 0.5 vector_similarity_weight = 0.5
for m in matchExprs: for m in matchExprs:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
@ -130,13 +143,12 @@ class ESConnection(DocStoreConnection):
minimum_should_match = "0%" minimum_should_match = "0%"
if "minimum_should_match" in m.extra_options: if "minimum_should_match" in m.extra_options:
minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%" minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
bqry = Q("bool", bqry.must.append(Q("query_string", fields=m.fields,
must=Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text, type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match, minimum_should_match=minimum_should_match,
boost=1), boost=1))
boost=1.0 - vector_similarity_weight, bqry.boost = 1.0 - vector_similarity_weight
)
elif isinstance(m, MatchDenseExpr): elif isinstance(m, MatchDenseExpr):
assert (bqry is not None) assert (bqry is not None)
similarity = 0.0 similarity = 0.0
@ -150,21 +162,6 @@ class ESConnection(DocStoreConnection):
similarity=similarity, similarity=similarity,
) )
condition["kb_id"] = knowledgebaseIds
if condition:
if not bqry:
bqry = Q("bool", must=[])
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
if bqry: if bqry:
s = s.query(bqry) s = s.query(bqry)
for field in highlightFields: for field in highlightFields:
@ -181,8 +178,7 @@ class ESConnection(DocStoreConnection):
if limit > 0: if limit > 0:
s = s[offset:limit] s = s[offset:limit]
q = s.to_dict() q = s.to_dict()
print(json.dumps(q), flush=True) logging.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
logging.debug("ESConnection.search [Q]: " + json.dumps(q))
for i in range(3): for i in range(3):
try: try:
@ -194,15 +190,15 @@ class ESConnection(DocStoreConnection):
_source=True) _source=True)
if str(res.get("timed_out", "")).lower() == "true": if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.") raise Exception("Es Timeout.")
logging.debug("ESConnection.search res: " + str(res)) logging.debug(f"ESConnection.search {str(indexNames)} res: " + str(res))
return res return res
except Exception as e: except Exception as e:
logging.exception("ES search [Q]: " + str(q)) logging.exception(f"ESConnection.search {str(indexNames)} query: " + str(q))
if str(e).find("Timeout") > 0: if str(e).find("Timeout") > 0:
continue continue
raise e raise e
logging.error("ES search timeout for 3 times!") logging.error("ESConnection.search timeout for 3 times!")
raise Exception("ES search timeout.") raise Exception("ESConnection.search timeout.")
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
for i in range(3): for i in range(3):
@ -217,12 +213,12 @@ class ESConnection(DocStoreConnection):
chunk["id"] = chunkId chunk["id"] = chunkId
return chunk return chunk
except Exception as e: except Exception as e:
logging.exception(f"ES get({chunkId}) got exception") logging.exception(f"ESConnection.get({chunkId}) got exception")
if str(e).find("Timeout") > 0: if str(e).find("Timeout") > 0:
continue continue
raise e raise e
logging.error("ES search timeout for 3 times!") logging.error("ESConnection.get timeout for 3 times!")
raise Exception("ES search timeout.") raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]: def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
@ -250,7 +246,7 @@ class ESConnection(DocStoreConnection):
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"])) res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res return res
except Exception as e: except Exception as e:
logging.warning("Fail to bulk: " + str(e)) logging.warning("ESConnection.insert got exception: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3) time.sleep(3)
continue continue
@ -268,7 +264,7 @@ class ESConnection(DocStoreConnection):
return True return True
except Exception as e: except Exception as e:
logging.exception( logging.exception(
f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})") f"ESConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
if str(e).find("Timeout") > 0: if str(e).find("Timeout") > 0:
continue continue
else: else:
@ -307,7 +303,7 @@ class ESConnection(DocStoreConnection):
_ = ubq.execute() _ = ubq.execute()
return True return True
except Exception as e: except Exception as e:
logging.error("ES update exception: " + str(e) + "[Q]:" + str(bqry.to_dict())) logging.error("ESConnection.update got exception: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue
return False return False
@ -329,7 +325,7 @@ class ESConnection(DocStoreConnection):
qry.must.append(Q("term", **{k: v})) qry.must.append(Q("term", **{k: v}))
else: else:
raise Exception("Condition value must be int, str or list.") raise Exception("Condition value must be int, str or list.")
logging.debug("ESConnection.delete [Q]: " + json.dumps(qry.to_dict())) logging.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(10): for _ in range(10):
try: try:
res = self.es.delete_by_query( res = self.es.delete_by_query(
@ -338,7 +334,7 @@ class ESConnection(DocStoreConnection):
refresh=True) refresh=True)
return res["deleted"] return res["deleted"]
except Exception as e: except Exception as e:
logging.warning("Fail to delete: " + str(filter) + str(e)) logging.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3) time.sleep(3)
continue continue
@ -447,10 +443,10 @@ class ESConnection(DocStoreConnection):
request_timeout="2s") request_timeout="2s")
return res return res
except ConnectionTimeout: except ConnectionTimeout:
logging.exception("ESConnection.sql timeout [Q]: " + sql) logging.exception("ESConnection.sql timeout")
continue continue
except Exception: except Exception:
logging.exception("ESConnection.sql got exception [Q]: " + sql) logging.exception("ESConnection.sql got exception")
return None return None
logging.error("ESConnection.sql timeout for 3 times!") logging.error("ESConnection.sql timeout for 3 times!")
return None return None