use to_df replace to_pl when get infinity Result (#5604)

### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] Performance Improvement

---------

Co-authored-by: wangwei <dwxiayi@163.com>
This commit is contained in:
汪威 2025-03-05 09:35:40 +08:00 committed by GitHub
parent 555c70672e
commit 76e8285904
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 70 additions and 75 deletions

View File

@ -9,10 +9,10 @@
"title_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"content_with_weight": {"type": "varchar", "default": ""},
"content_ltks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
@ -28,7 +28,7 @@
"rank_flt": {"type": "float", "default": 0},
"available_int": {"type": "integer", "default": 1},
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"pagerank_fea": {"type": "integer", "default": 0},
"tag_feas": {"type": "varchar", "default": ""},

View File

@ -28,8 +28,7 @@ from infinity.errors import ErrorCode
from rag import settings
from rag.settings import PAGERANK_FLD
from rag.utils import singleton
import polars as pl
from polars.series.series import Series
import pandas as pd
from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import (
@ -90,20 +89,20 @@ def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | N
return " AND ".join(cond) if cond else "1=1"
def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
"""
Concatenate multiple dataframes into one.
"""
df_list = [df for df in df_list if not df.is_empty()]
if df_list:
return pl.concat(df_list)
schema = dict()
def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
df_list2 = [df for df in df_list if not df.empty]
if df_list2:
return pd.concat(df_list2, axis=0).reset_index(drop=True)
schema = []
for field_name in selectFields:
if field_name == 'score()': # Workaround: fix schema is changed to score()
schema['SCORE'] = str
schema.append('SCORE')
elif field_name == 'similarity()': # Workaround: fix schema is changed to similarity()
schema.append('SIMILARITY')
else:
schema[field_name] = str
return pl.DataFrame(schema=schema)
schema.append(field_name)
return pd.DataFrame(columns=schema)
@singleton
@ -121,7 +120,7 @@ class InfinityConnection(DocStoreConnection):
connPool = ConnectionPool(infinity_uri)
inf_conn = connPool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status == "started":
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self._migrate_db(inf_conn)
self.connPool = connPool
connPool.release_conn(inf_conn)
@ -189,7 +188,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool.release_conn(inf_conn)
res2 = {
"type": "infinity",
"status": "green" if res.error_code == 0 and res.server_status == "started" else "red",
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
"error": res.error_msg,
}
return res2
@ -281,7 +280,7 @@ class InfinityConnection(DocStoreConnection):
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None
) -> list[dict] | pl.DataFrame:
) -> tuple[pd.DataFrame, int]:
"""
TODO: Infinity doesn't provide highlight
"""
@ -292,9 +291,10 @@ class InfinityConnection(DocStoreConnection):
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
output = selectFields.copy()
for essential_field in ["id"]:
if essential_field not in selectFields:
selectFields.append(essential_field)
if essential_field not in output:
output.append(essential_field)
score_func = ""
score_column = ""
for matchExpr in matchExprs:
@ -309,9 +309,11 @@ class InfinityConnection(DocStoreConnection):
score_column = "SIMILARITY"
break
if matchExprs:
selectFields.append(score_func)
selectFields.append(PAGERANK_FLD)
selectFields = [f for f in selectFields if f != "_score"]
if score_func not in output:
output.append(score_func)
if PAGERANK_FLD not in output:
output.append(PAGERANK_FLD)
output = [f for f in output if f != "_score"]
# Prepare expressions common to all tables
filter_cond = None
@ -339,7 +341,7 @@ class InfinityConnection(DocStoreConnection):
matchExpr.extra_options[k] = str(v)
logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and filter_cond and "filter" not in matchExpr.extra_options:
if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
@ -370,7 +372,7 @@ class InfinityConnection(DocStoreConnection):
except Exception:
continue
table_list.append(table_name)
builder = table_instance.output(selectFields)
builder = table_instance.output(output)
if len(matchExprs) > 0:
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
@ -379,7 +381,7 @@ class InfinityConnection(DocStoreConnection):
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
@ -388,7 +390,7 @@ class InfinityConnection(DocStoreConnection):
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(
@ -400,18 +402,17 @@ class InfinityConnection(DocStoreConnection):
if orderBy.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result:
total_hits_count += int(extra_result["total_hits_count"])
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, selectFields)
res = concat_dataframes(df_list, output)
if matchExprs:
res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
if score_column and score_column != "SCORE":
res = res.rename({score_column: "_score"})
res = res.limit(limit)
res['Sum'] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by='Sum', ascending=False).reset_index(drop=True).drop(columns=['Sum'])
res = res.head(limit)
logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
@ -433,12 +434,12 @@ class InfinityConnection(DocStoreConnection):
logger.warning(
f"Table not found: {table_name}, this knowledge base isn't created in Infinity. Maybe it is created in other document engine.")
continue
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"])
res_fields = self.getFields(res, res.columns)
res_fields = self.getFields(res, res.columns.tolist())
return res_fields.get(chunkId, None)
def insert(
@ -572,60 +573,54 @@ class InfinityConnection(DocStoreConnection):
Helper functions for search result
"""
def getTotal(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> int:
def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def getChunkIds(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> list[str]:
def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[str]) -> list[str, dict]:
def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
res_fields = {}
if not fields:
return {}
num_rows = len(res)
column_id = res["id"]
for i in range(num_rows):
id = column_id[i]
m = {"id": id}
for fieldnm in fields:
if fieldnm not in res:
m[fieldnm] = None
continue
v = res[fieldnm][i]
if isinstance(v, Series):
v = list(v)
elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
assert isinstance(v, str)
v = [kwd for kwd in v.split("###") if kwd]
elif fieldnm == "position_int":
assert isinstance(v, str)
fieldsAll = fields.copy()
fieldsAll.append('id')
column_map = {col.lower(): col for col in res.columns}
matched_columns = {column_map[col.lower()]:col for col in set(fieldsAll) if col.lower() in column_map}
none_columns = [col for col in set(fieldsAll) if col.lower() not in column_map]
res2 = res[matched_columns.keys()]
res2 = res2.rename(columns=matched_columns)
res2.drop_duplicates(subset=['id'], inplace=True)
for column in res2.columns:
k = column.lower()
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
res2[column] = res2[column].apply(lambda v:[kwd for kwd in v.split("###") if kwd])
elif k == "position_int":
def to_position_int(v):
if v:
arr = [int(hex_val, 16) for hex_val in v.split('_')]
v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
else:
v = []
elif fieldnm in ["page_num_int", "top_int"]:
assert isinstance(v, str)
if v:
v = [int(hex_val, 16) for hex_val in v.split('_')]
return v
res2[column] = res2[column].apply(to_position_int)
elif k in ["page_num_int", "top_int"]:
res2[column] = res2[column].apply(lambda v:[int(hex_val, 16) for hex_val in v.split('_')] if v else [])
else:
v = []
else:
if not isinstance(v, str):
v = str(v)
# if fieldnm.endswith("_tks"):
# v = rmSpace(v)
m[fieldnm] = v
res_fields[id] = m
return res_fields
pass
for column in none_columns:
res2[column] = None
def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: list[str], fieldnm: str):
return res2.set_index("id").to_dict(orient="index")
def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
@ -655,7 +650,7 @@ class InfinityConnection(DocStoreConnection):
ans[id] = "...".join(txts)
return ans
def getAggregation(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fieldnm: str):
def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
"""
TODO: Infinity doesn't provide aggregation
"""