use to_df replace to_pl when get infinity Result (#5604)

### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] Performance Improvement

---------

Co-authored-by: wangwei <dwxiayi@163.com>
This commit is contained in:
汪威 2025-03-05 09:35:40 +08:00 committed by GitHub
parent 555c70672e
commit 76e8285904
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 70 additions and 75 deletions

View File

@ -9,10 +9,10 @@
"title_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "title_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"content_with_weight": {"type": "varchar", "default": ""}, "content_with_weight": {"type": "varchar", "default": ""},
"content_ltks": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "content_ltks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
@ -28,7 +28,7 @@
"rank_flt": {"type": "float", "default": 0}, "rank_flt": {"type": "float", "default": 0},
"available_int": {"type": "integer", "default": 1}, "available_int": {"type": "integer", "default": 1},
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"}, "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"pagerank_fea": {"type": "integer", "default": 0}, "pagerank_fea": {"type": "integer", "default": 0},
"tag_feas": {"type": "varchar", "default": ""}, "tag_feas": {"type": "varchar", "default": ""},

View File

@ -28,8 +28,7 @@ from infinity.errors import ErrorCode
from rag import settings from rag import settings
from rag.settings import PAGERANK_FLD from rag.settings import PAGERANK_FLD
from rag.utils import singleton from rag.utils import singleton
import polars as pl import pandas as pd
from polars.series.series import Series
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import ( from rag.utils.doc_store_conn import (
@ -90,20 +89,20 @@ def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | N
return " AND ".join(cond) if cond else "1=1" return " AND ".join(cond) if cond else "1=1"
def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame: def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
""" df_list2 = [df for df in df_list if not df.empty]
Concatenate multiple dataframes into one. if df_list2:
""" return pd.concat(df_list2, axis=0).reset_index(drop=True)
df_list = [df for df in df_list if not df.is_empty()]
if df_list: schema = []
return pl.concat(df_list)
schema = dict()
for field_name in selectFields: for field_name in selectFields:
if field_name == 'score()': # Workaround: fix schema is changed to score() if field_name == 'score()': # Workaround: fix schema is changed to score()
schema['SCORE'] = str schema.append('SCORE')
elif field_name == 'similarity()': # Workaround: fix schema is changed to similarity()
schema.append('SIMILARITY')
else: else:
schema[field_name] = str schema.append(field_name)
return pl.DataFrame(schema=schema) return pd.DataFrame(columns=schema)
@singleton @singleton
@ -121,7 +120,7 @@ class InfinityConnection(DocStoreConnection):
connPool = ConnectionPool(infinity_uri) connPool = ConnectionPool(infinity_uri)
inf_conn = connPool.get_conn() inf_conn = connPool.get_conn()
res = inf_conn.show_current_node() res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status == "started": if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self._migrate_db(inf_conn) self._migrate_db(inf_conn)
self.connPool = connPool self.connPool = connPool
connPool.release_conn(inf_conn) connPool.release_conn(inf_conn)
@ -189,7 +188,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)
res2 = { res2 = {
"type": "infinity", "type": "infinity",
"status": "green" if res.error_code == 0 and res.server_status == "started" else "red", "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
"error": res.error_msg, "error": res.error_msg,
} }
return res2 return res2
@ -281,7 +280,7 @@ class InfinityConnection(DocStoreConnection):
knowledgebaseIds: list[str], knowledgebaseIds: list[str],
aggFields: list[str] = [], aggFields: list[str] = [],
rank_feature: dict | None = None rank_feature: dict | None = None
) -> list[dict] | pl.DataFrame: ) -> tuple[pd.DataFrame, int]:
""" """
TODO: Infinity doesn't provide highlight TODO: Infinity doesn't provide highlight
""" """
@ -292,9 +291,10 @@ class InfinityConnection(DocStoreConnection):
db_instance = inf_conn.get_database(self.dbName) db_instance = inf_conn.get_database(self.dbName)
df_list = list() df_list = list()
table_list = list() table_list = list()
output = selectFields.copy()
for essential_field in ["id"]: for essential_field in ["id"]:
if essential_field not in selectFields: if essential_field not in output:
selectFields.append(essential_field) output.append(essential_field)
score_func = "" score_func = ""
score_column = "" score_column = ""
for matchExpr in matchExprs: for matchExpr in matchExprs:
@ -309,9 +309,11 @@ class InfinityConnection(DocStoreConnection):
score_column = "SIMILARITY" score_column = "SIMILARITY"
break break
if matchExprs: if matchExprs:
selectFields.append(score_func) if score_func not in output:
selectFields.append(PAGERANK_FLD) output.append(score_func)
selectFields = [f for f in selectFields if f != "_score"] if PAGERANK_FLD not in output:
output.append(PAGERANK_FLD)
output = [f for f in output if f != "_score"]
# Prepare expressions common to all tables # Prepare expressions common to all tables
filter_cond = None filter_cond = None
@ -339,7 +341,7 @@ class InfinityConnection(DocStoreConnection):
matchExpr.extra_options[k] = str(v) matchExpr.extra_options[k] = str(v)
logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}") logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr): elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and filter_cond and "filter" not in matchExpr.extra_options: if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext}) matchExpr.extra_options.update({"filter": filter_fulltext})
for k, v in matchExpr.extra_options.items(): for k, v in matchExpr.extra_options.items():
if not isinstance(v, str): if not isinstance(v, str):
@ -370,7 +372,7 @@ class InfinityConnection(DocStoreConnection):
except Exception: except Exception:
continue continue
table_list.append(table_name) table_list.append(table_name)
builder = table_instance.output(selectFields) builder = table_instance.output(output)
if len(matchExprs) > 0: if len(matchExprs) > 0:
for matchExpr in matchExprs: for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr): if isinstance(matchExpr, MatchTextExpr):
@ -379,7 +381,7 @@ class InfinityConnection(DocStoreConnection):
fields, fields,
matchExpr.matching_text, matchExpr.matching_text,
matchExpr.topn, matchExpr.topn,
matchExpr.extra_options, matchExpr.extra_options.copy(),
) )
elif isinstance(matchExpr, MatchDenseExpr): elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense( builder = builder.match_dense(
@ -388,7 +390,7 @@ class InfinityConnection(DocStoreConnection):
matchExpr.embedding_data_type, matchExpr.embedding_data_type,
matchExpr.distance_type, matchExpr.distance_type,
matchExpr.topn, matchExpr.topn,
matchExpr.extra_options, matchExpr.extra_options.copy(),
) )
elif isinstance(matchExpr, FusionExpr): elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion( builder = builder.fusion(
@ -400,18 +402,17 @@ class InfinityConnection(DocStoreConnection):
if orderBy.fields: if orderBy.fields:
builder.sort(order_by_expr_list) builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit) builder.offset(offset).limit(limit)
kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl() kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result: if extra_result:
total_hits_count += int(extra_result["total_hits_count"]) total_hits_count += int(extra_result["total_hits_count"])
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}") logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
df_list.append(kb_res) df_list.append(kb_res)
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, selectFields) res = concat_dataframes(df_list, output)
if matchExprs: if matchExprs:
res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True) res['Sum'] = res[score_column] + res[PAGERANK_FLD]
if score_column and score_column != "SCORE": res = res.sort_values(by='Sum', ascending=False).reset_index(drop=True).drop(columns=['Sum'])
res = res.rename({score_column: "_score"}) res = res.head(limit)
res = res.limit(limit)
logger.debug(f"INFINITY search final result: {str(res)}") logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count return res, total_hits_count
@ -433,12 +434,12 @@ class InfinityConnection(DocStoreConnection):
logger.warning( logger.warning(
f"Table not found: {table_name}, this knowledge base isn't created in Infinity. Maybe it is created in other document engine.") f"Table not found: {table_name}, this knowledge base isn't created in Infinity. Maybe it is created in other document engine.")
continue continue
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl() kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}") logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
df_list.append(kb_res) df_list.append(kb_res)
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"]) res = concat_dataframes(df_list, ["id"])
res_fields = self.getFields(res, res.columns) res_fields = self.getFields(res, res.columns.tolist())
return res_fields.get(chunkId, None) return res_fields.get(chunkId, None)
def insert( def insert(
@ -572,60 +573,54 @@ class InfinityConnection(DocStoreConnection):
Helper functions for search result Helper functions for search result
""" """
def getTotal(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> int: def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple): if isinstance(res, tuple):
return res[1] return res[1]
return len(res) return len(res)
def getChunkIds(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> list[str]: def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple): if isinstance(res, tuple):
res = res[0] res = res[0]
return list(res["id"]) return list(res["id"])
def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[str]) -> list[str, dict]: def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple): if isinstance(res, tuple):
res = res[0] res = res[0]
res_fields = {}
if not fields: if not fields:
return {} return {}
num_rows = len(res) fieldsAll = fields.copy()
column_id = res["id"] fieldsAll.append('id')
for i in range(num_rows): column_map = {col.lower(): col for col in res.columns}
id = column_id[i] matched_columns = {column_map[col.lower()]:col for col in set(fieldsAll) if col.lower() in column_map}
m = {"id": id} none_columns = [col for col in set(fieldsAll) if col.lower() not in column_map]
for fieldnm in fields:
if fieldnm not in res: res2 = res[matched_columns.keys()]
m[fieldnm] = None res2 = res2.rename(columns=matched_columns)
continue res2.drop_duplicates(subset=['id'], inplace=True)
v = res[fieldnm][i]
if isinstance(v, Series): for column in res2.columns:
v = list(v) k = column.lower()
elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
assert isinstance(v, str) res2[column] = res2[column].apply(lambda v:[kwd for kwd in v.split("###") if kwd])
v = [kwd for kwd in v.split("###") if kwd] elif k == "position_int":
elif fieldnm == "position_int": def to_position_int(v):
assert isinstance(v, str)
if v: if v:
arr = [int(hex_val, 16) for hex_val in v.split('_')] arr = [int(hex_val, 16) for hex_val in v.split('_')]
v = [arr[i:i + 5] for i in range(0, len(arr), 5)] v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
else: else:
v = [] v = []
elif fieldnm in ["page_num_int", "top_int"]: return v
assert isinstance(v, str) res2[column] = res2[column].apply(to_position_int)
if v: elif k in ["page_num_int", "top_int"]:
v = [int(hex_val, 16) for hex_val in v.split('_')] res2[column] = res2[column].apply(lambda v:[int(hex_val, 16) for hex_val in v.split('_')] if v else [])
else: else:
v = [] pass
else: for column in none_columns:
if not isinstance(v, str): res2[column] = None
v = str(v)
# if fieldnm.endswith("_tks"): return res2.set_index("id").to_dict(orient="index")
# v = rmSpace(v)
m[fieldnm] = v
res_fields[id] = m
return res_fields
def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: list[str], fieldnm: str): def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
if isinstance(res, tuple): if isinstance(res, tuple):
res = res[0] res = res[0]
ans = {} ans = {}
@ -655,7 +650,7 @@ class InfinityConnection(DocStoreConnection):
ans[id] = "...".join(txts) ans[id] = "...".join(txts)
return ans return ans
def getAggregation(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fieldnm: str): def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
""" """
TODO: Infinity doesn't provide aggregation TODO: Infinity doesn't provide aggregation
""" """