Added pagerank support to infinity (#4059)

### What problem does this PR solve?

Added pagerank support to infinity

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Zhichang Yu 2024-12-17 15:45:01 +08:00 committed by GitHub
parent fddac1345d
commit bcccaccc2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 9 deletions

View File

@ -107,6 +107,7 @@ def update():
settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]}, settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
search.index_name(kb.tenant_id), kb.id) search.index_name(kb.tenant_id), kb.id)
else: else:
# Elasticsearch requires pagerank_fea be non-zero!
settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"}, settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
search.index_name(kb.tenant_id), kb.id) search.index_name(kb.tenant_id), kb.id)

View File

@ -46,13 +46,14 @@ def equivalent_condition_to_str(condition: dict) -> str|None:
cond.append(f"{k}='{v}'") cond.append(f"{k}='{v}'")
else: else:
cond.append(f"{k}={str(v)}") cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else None return " AND ".join(cond) if cond else "1=1"
def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame: def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
""" """
Concatenate multiple dataframes into one. Concatenate multiple dataframes into one.
""" """
df_list = [df for df in df_list if not df.is_empty()]
if df_list: if df_list:
return pl.concat(df_list) return pl.concat(df_list)
schema = dict() schema = dict()
@ -246,8 +247,9 @@ class InfinityConnection(DocStoreConnection):
db_instance = inf_conn.get_database(self.dbName) db_instance = inf_conn.get_database(self.dbName)
df_list = list() df_list = list()
table_list = list() table_list = list()
if "id" not in selectFields: for essential_field in ["id", "score()", "pagerank_fea"]:
selectFields.append("id") if essential_field not in selectFields:
selectFields.append(essential_field)
# Prepare expressions common to all tables # Prepare expressions common to all tables
filter_cond = None filter_cond = None
@ -331,10 +333,13 @@ class InfinityConnection(DocStoreConnection):
kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl() kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
if extra_result: if extra_result:
total_hits_count += int(extra_result["total_hits_count"]) total_hits_count += int(extra_result["total_hits_count"])
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
df_list.append(kb_res) df_list.append(kb_res)
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, selectFields) res = concat_dataframes(df_list, selectFields)
logger.debug(f"INFINITY search tables: {str(table_list)}, result: {str(res)}") res = res.sort(pl.col("SCORE") + pl.col("pagerank_fea"), descending=True, maintain_order=True)
res = res.limit(limit)
logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count return res, total_hits_count
def get( def get(
@ -350,12 +355,10 @@ class InfinityConnection(DocStoreConnection):
table_list.append(table_name) table_list.append(table_name)
table_instance = db_instance.get_table(table_name) table_instance = db_instance.get_table(table_name)
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl() kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
if len(kb_res) != 0 and kb_res.shape[0] > 0: logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
df_list.append(kb_res) df_list.append(kb_res)
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"]) res = concat_dataframes(df_list, ["id"])
logger.debug(f"INFINITY get tables: {str(table_list)}, result: {str(res)}")
res_fields = self.getFields(res, res.columns) res_fields = self.getFields(res, res.columns)
return res_fields.get(chunkId, None) return res_fields.get(chunkId, None)
@ -421,8 +424,10 @@ class InfinityConnection(DocStoreConnection):
db_instance = inf_conn.get_database(self.dbName) db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}" table_name = f"{indexName}_{knowledgebaseId}"
table_instance = db_instance.get_table(table_name) table_instance = db_instance.get_table(table_name)
if "exist" in condition:
del condition["exist"]
filter = equivalent_condition_to_str(condition) filter = equivalent_condition_to_str(condition)
for k, v in newValue.items(): for k, v in list(newValue.items()):
if k.endswith("_kwd") and isinstance(v, list): if k.endswith("_kwd") and isinstance(v, list):
newValue[k] = " ".join(v) newValue[k] = " ".join(v)
elif k == 'kb_id': elif k == 'kb_id':
@ -435,6 +440,9 @@ class InfinityConnection(DocStoreConnection):
elif k in ["page_num_int", "top_int"]: elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list) assert isinstance(v, list)
newValue[k] = "_".join(f"{num:08x}" for num in v) newValue[k] = "_".join(f"{num:08x}" for num in v)
elif k == "remove" and v in ["pagerank_fea"]:
del newValue[k]
newValue[v] = 0
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.") logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
table_instance.update(filter, newValue) table_instance.update(filter, newValue)
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)