Use to_df instead of to_pl when fetching Infinity results (#5604)
### What problem does this PR solve?

Switch the Infinity connector from polars to pandas when materializing query results: every `to_pl()` call becomes `to_df()`, and the downstream post-processing (result concatenation, score + pagerank re-ranking, and field extraction in `getFields`) is rewritten with pandas operations. The keyword-list columns (`important_kwd`, `tag_kwd`, `question_kwd`, `entities_kwd`) also move to the `whitespace-#` analyzer, and the health check now accepts an `alive` server status in addition to `started`.

### Type of change

- [x] Performance Improvement

---------

Co-authored-by: wangwei <dwxiayi@163.com>
This commit is contained in:
commit 76e8285904 (parent 555c70672e)
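Context for reviewers: every Infinity query result is now materialized with `to_df()` (pandas) instead of `to_pl()` (polars), and the post-processing changes idiom accordingly. A rough phrasebook of the correspondences this diff relies on, sketched on a throwaway frame rather than taken from the PR:

```python
import pandas as pd

# Rough polars -> pandas correspondences used throughout this diff.
df = pd.DataFrame({"id": ["a", "b"], "SCORE": [0.2, 0.9]})

df.empty                                      # was: df.is_empty()
pd.concat([df, df], axis=0)                   # was: pl.concat([df, df])
df.sort_values(by="SCORE", ascending=False)   # was: df.sort(..., descending=True)
df.head(1)                                    # was: df.limit(1)
```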
@@ -9,10 +9,10 @@
   "title_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-  "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-  "tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
+  "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
+  "tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
   "important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-  "question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
+  "question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
   "question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "content_with_weight": {"type": "varchar", "default": ""},
   "content_ltks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
@@ -28,7 +28,7 @@
   "rank_flt": {"type": "float", "default": 0},
   "available_int": {"type": "integer", "default": 1},
   "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-  "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
+  "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
   "pagerank_fea": {"type": "integer", "default": 0},
   "tag_feas": {"type": "varchar", "default": ""},
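Note on the mapping change: `whitespace-#` is applied to exactly the keyword-list columns (`important_kwd`, `tag_kwd`, `question_kwd`, `entities_kwd`), which this codebase stores as single `###`-joined strings; presumably the analyzer also splits on `#`, so the joined form indexes as individual keywords. A small illustration of the storage convention (the split mirrors `getFields` later in this diff):

```python
# Hedged sketch of the ###-joined storage convention for *_kwd fields;
# the split below mirrors the getFields code further down in this diff.
keywords = ["machine learning", "knowledge graph"]
stored = "###".join(keywords)            # 'machine learning###knowledge graph'
restored = [kwd for kwd in stored.split("###") if kwd]
assert restored == keywords
```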
@@ -28,8 +28,7 @@ from infinity.errors import ErrorCode
 from rag import settings
 from rag.settings import PAGERANK_FLD
 from rag.utils import singleton
-import polars as pl
-from polars.series.series import Series
+import pandas as pd
 from api.utils.file_utils import get_project_base_directory

 from rag.utils.doc_store_conn import (
@@ -90,20 +89,20 @@ def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
     return " AND ".join(cond) if cond else "1=1"


-def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
-    """
-    Concatenate multiple dataframes into one.
-    """
-    df_list = [df for df in df_list if not df.is_empty()]
-    if df_list:
-        return pl.concat(df_list)
-    schema = dict()
+def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
+    df_list2 = [df for df in df_list if not df.empty]
+    if df_list2:
+        return pd.concat(df_list2, axis=0).reset_index(drop=True)
+
+    schema = []
     for field_name in selectFields:
         if field_name == 'score()':  # Workaround: fix schema is changed to score()
-            schema['SCORE'] = str
+            schema.append('SCORE')
+        elif field_name == 'similarity()':  # Workaround: fix schema is changed to similarity()
+            schema.append('SIMILARITY')
         else:
-            schema[field_name] = str
-    return pl.DataFrame(schema=schema)
+            schema.append(field_name)
+    return pd.DataFrame(columns=schema)


 @singleton
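A quick sanity check of the rewritten fallback path, assuming `concat_dataframes` as defined above is in scope: when every input frame is empty, the result is an empty pandas frame whose columns follow `selectFields`, with the `score()` and `similarity()` pseudo-fields mapped to their column aliases:

```python
import pandas as pd

# Exercise the empty-input fallback of the new concat_dataframes.
res = concat_dataframes([pd.DataFrame()], ["id", "score()", "similarity()"])
print(list(res.columns))  # ['id', 'SCORE', 'SIMILARITY']
print(res.empty)          # True
```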
@@ -121,7 +120,7 @@ class InfinityConnection(DocStoreConnection):
                 connPool = ConnectionPool(infinity_uri)
                 inf_conn = connPool.get_conn()
                 res = inf_conn.show_current_node()
-                if res.error_code == ErrorCode.OK and res.server_status == "started":
+                if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
                     self._migrate_db(inf_conn)
                     self.connPool = connPool
                     connPool.release_conn(inf_conn)
@@ -189,7 +188,7 @@ class InfinityConnection(DocStoreConnection):
         self.connPool.release_conn(inf_conn)
         res2 = {
             "type": "infinity",
-            "status": "green" if res.error_code == 0 and res.server_status == "started" else "red",
+            "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
             "error": res.error_msg,
         }
         return res2
@@ -281,7 +280,7 @@ class InfinityConnection(DocStoreConnection):
             knowledgebaseIds: list[str],
             aggFields: list[str] = [],
             rank_feature: dict | None = None
-    ) -> list[dict] | pl.DataFrame:
+    ) -> tuple[pd.DataFrame, int]:
         """
         TODO: Infinity doesn't provide highlight
         """
@@ -292,9 +291,10 @@ class InfinityConnection(DocStoreConnection):
         db_instance = inf_conn.get_database(self.dbName)
         df_list = list()
         table_list = list()
+        output = selectFields.copy()
         for essential_field in ["id"]:
-            if essential_field not in selectFields:
-                selectFields.append(essential_field)
+            if essential_field not in output:
+                output.append(essential_field)
         score_func = ""
         score_column = ""
         for matchExpr in matchExprs:
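The new local `output` list is more than a rename: appending to the `selectFields` parameter mutated the caller's list on every search. A minimal sketch of the aliasing bug being avoided (function names hypothetical):

```python
# Hedged sketch of the bug class this avoids: appending to a parameter list
# mutates the caller's object, so repeated searches grow the same list.
def search_old(select_fields: list[str]) -> list[str]:
    select_fields.append("id")       # caller's list is modified in place
    return select_fields

def search_new(select_fields: list[str]) -> list[str]:
    output = select_fields.copy()    # work on a private copy instead
    output.append("id")
    return output

fields = ["content_ltks"]
search_old(fields)
print(fields)  # ['content_ltks', 'id'] -- leaked mutation
fields = ["content_ltks"]
search_new(fields)
print(fields)  # ['content_ltks'] -- caller unaffected
```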
@@ -309,9 +309,11 @@ class InfinityConnection(DocStoreConnection):
                 score_column = "SIMILARITY"
                 break
         if matchExprs:
-            selectFields.append(score_func)
-            selectFields.append(PAGERANK_FLD)
-            selectFields = [f for f in selectFields if f != "_score"]
+            if score_func not in output:
+                output.append(score_func)
+            if PAGERANK_FLD not in output:
+                output.append(PAGERANK_FLD)
+            output = [f for f in output if f != "_score"]

         # Prepare expressions common to all tables
         filter_cond = None
@@ -339,7 +341,7 @@ class InfinityConnection(DocStoreConnection):
                         matchExpr.extra_options[k] = str(v)
                 logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
             elif isinstance(matchExpr, MatchDenseExpr):
-                if filter_fulltext and filter_cond and "filter" not in matchExpr.extra_options:
+                if filter_fulltext and "filter" not in matchExpr.extra_options:
                     matchExpr.extra_options.update({"filter": filter_fulltext})
                 for k, v in matchExpr.extra_options.items():
                     if not isinstance(v, str):
@@ -370,7 +372,7 @@ class InfinityConnection(DocStoreConnection):
             except Exception:
                 continue
             table_list.append(table_name)
-            builder = table_instance.output(selectFields)
+            builder = table_instance.output(output)
             if len(matchExprs) > 0:
                 for matchExpr in matchExprs:
                     if isinstance(matchExpr, MatchTextExpr):
@@ -379,7 +381,7 @@ class InfinityConnection(DocStoreConnection):
                             fields,
                             matchExpr.matching_text,
                             matchExpr.topn,
-                            matchExpr.extra_options,
+                            matchExpr.extra_options.copy(),
                         )
                     elif isinstance(matchExpr, MatchDenseExpr):
                         builder = builder.match_dense(
@@ -388,7 +390,7 @@ class InfinityConnection(DocStoreConnection):
                             matchExpr.embedding_data_type,
                             matchExpr.distance_type,
                             matchExpr.topn,
-                            matchExpr.extra_options,
+                            matchExpr.extra_options.copy(),
                         )
                     elif isinstance(matchExpr, FusionExpr):
                         builder = builder.fusion(
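Both `match_text` and `match_dense` now receive `matchExpr.extra_options.copy()`. The same match expressions are reused for every table in the loop, and option values are stringified in place earlier in this method, so a fresh dict per table presumably keeps one table's query from seeing another's mutations. A generic sketch of the hazard (names hypothetical):

```python
# Hedged sketch: one shared options dict handed to a builder that may
# write into whatever dict it receives, once per table.
extra_options = {"filter": "available_int = 1"}

def match_text(opts: dict) -> dict:
    opts.setdefault("topn", "6")     # hypothetical in-place tweak by the builder
    return opts

for table in ["kb_1", "kb_2"]:
    match_text(extra_options.copy())  # each table gets an untouched copy

print(extra_options)  # {'filter': 'available_int = 1'} -- still pristine
```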
@@ -400,18 +402,17 @@ class InfinityConnection(DocStoreConnection):
             if orderBy.fields:
                 builder.sort(order_by_expr_list)
             builder.offset(offset).limit(limit)
-            kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
+            kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
             if extra_result:
                 total_hits_count += int(extra_result["total_hits_count"])
             logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
             df_list.append(kb_res)
         self.connPool.release_conn(inf_conn)
-        res = concat_dataframes(df_list, selectFields)
+        res = concat_dataframes(df_list, output)
         if matchExprs:
-            res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
-            if score_column and score_column != "SCORE":
-                res = res.rename({score_column: "_score"})
-            res = res.limit(limit)
+            res['Sum'] = res[score_column] + res[PAGERANK_FLD]
+            res = res.sort_values(by='Sum', ascending=False).reset_index(drop=True).drop(columns=['Sum'])
+            res = res.head(limit)
         logger.debug(f"INFINITY search final result: {str(res)}")
         return res, total_hits_count

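The polars sort over `score + pagerank` becomes a temporary `Sum` column in pandas. A standalone sketch of that re-ranking step, with made-up rows and the column names from the diff:

```python
import pandas as pd

# Hedged sketch of the new pandas re-ranking: sort by score + pagerank,
# then trim to the requested limit.
res = pd.DataFrame({
    "id": ["a", "b", "c"],
    "SCORE": [0.9, 0.5, 0.7],
    "pagerank_fea": [0.0, 0.6, 0.0],
})
res["Sum"] = res["SCORE"] + res["pagerank_fea"]
res = res.sort_values(by="Sum", ascending=False).reset_index(drop=True).drop(columns=["Sum"])
res = res.head(2)
print(res["id"].tolist())  # ['b', 'a']
```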
@@ -433,12 +434,12 @@ class InfinityConnection(DocStoreConnection):
                 logger.warning(
                     f"Table not found: {table_name}, this knowledge base isn't created in Infinity. Maybe it is created in other document engine.")
                 continue
-            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
+            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
             logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
             df_list.append(kb_res)
         self.connPool.release_conn(inf_conn)
         res = concat_dataframes(df_list, ["id"])
-        res_fields = self.getFields(res, res.columns)
+        res_fields = self.getFields(res, res.columns.tolist())
         return res_fields.get(chunkId, None)

     def insert(
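Minor but necessary: `res.columns` is a pandas `Index`, so it is converted with `.tolist()` before being passed as the `fields` argument. For illustration:

```python
import pandas as pd

# pandas columns form an Index object, not a list; tolist() hands getFields
# a plain list of field names.
df = pd.DataFrame(columns=["id", "content_ltks"])
print(type(df.columns).__name__)  # Index
print(df.columns.tolist())        # ['id', 'content_ltks']
```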
@@ -572,60 +573,54 @@ class InfinityConnection(DocStoreConnection):
     Helper functions for search result
     """

-    def getTotal(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> int:
+    def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
         if isinstance(res, tuple):
             return res[1]
         return len(res)

-    def getChunkIds(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> list[str]:
+    def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
         if isinstance(res, tuple):
             res = res[0]
         return list(res["id"])

-    def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[str]) -> list[str, dict]:
+    def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
         if isinstance(res, tuple):
             res = res[0]
-        res_fields = {}
         if not fields:
             return {}
-        num_rows = len(res)
-        column_id = res["id"]
-        for i in range(num_rows):
-            id = column_id[i]
-            m = {"id": id}
-            for fieldnm in fields:
-                if fieldnm not in res:
-                    m[fieldnm] = None
-                    continue
-                v = res[fieldnm][i]
-                if isinstance(v, Series):
-                    v = list(v)
-                elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
-                    assert isinstance(v, str)
-                    v = [kwd for kwd in v.split("###") if kwd]
-                elif fieldnm == "position_int":
-                    assert isinstance(v, str)
-                    if v:
-                        arr = [int(hex_val, 16) for hex_val in v.split('_')]
-                        v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
-                    else:
-                        v = []
-                elif fieldnm in ["page_num_int", "top_int"]:
-                    assert isinstance(v, str)
-                    if v:
-                        v = [int(hex_val, 16) for hex_val in v.split('_')]
-                    else:
-                        v = []
-                else:
-                    if not isinstance(v, str):
-                        v = str(v)
-                    # if fieldnm.endswith("_tks"):
-                    #     v = rmSpace(v)
-                m[fieldnm] = v
-            res_fields[id] = m
-        return res_fields
+        fieldsAll = fields.copy()
+        fieldsAll.append('id')
+        column_map = {col.lower(): col for col in res.columns}
+        matched_columns = {column_map[col.lower()]: col for col in set(fieldsAll) if col.lower() in column_map}
+        none_columns = [col for col in set(fieldsAll) if col.lower() not in column_map]
+
+        res2 = res[matched_columns.keys()]
+        res2 = res2.rename(columns=matched_columns)
+        res2.drop_duplicates(subset=['id'], inplace=True)
+
+        for column in res2.columns:
+            k = column.lower()
+            if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
+                res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
+            elif k == "position_int":
+                def to_position_int(v):
+                    if v:
+                        arr = [int(hex_val, 16) for hex_val in v.split('_')]
+                        v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
+                    else:
+                        v = []
+                    return v
+                res2[column] = res2[column].apply(to_position_int)
+            elif k in ["page_num_int", "top_int"]:
+                res2[column] = res2[column].apply(lambda v: [int(hex_val, 16) for hex_val in v.split('_')] if v else [])
+            else:
+                pass
+        for column in none_columns:
+            res2[column] = None
+
+        return res2.set_index("id").to_dict(orient="index")

-    def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: list[str], fieldnm: str):
+    def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
         if isinstance(res, tuple):
             res = res[0]
         ans = {}
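The rewritten `getFields` replaces the row-by-row loop with vectorized `apply` calls and a final `set_index("id").to_dict(orient="index")`. A compact sketch of the resulting output shape, assuming a single keyword column:

```python
import pandas as pd

# Hedged sketch: a frame indexed by id becomes {id: {field: value, ...}}.
res = pd.DataFrame({
    "id": ["c1", "c2"],
    "important_kwd": ["alpha###beta", ""],
})
res["important_kwd"] = res["important_kwd"].apply(
    lambda v: [kwd for kwd in v.split("###") if kwd]
)
print(res.set_index("id").to_dict(orient="index"))
# {'c1': {'important_kwd': ['alpha', 'beta']}, 'c2': {'important_kwd': []}}
```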
@@ -655,7 +650,7 @@ class InfinityConnection(DocStoreConnection):
             ans[id] = "...".join(txts)
         return ans

-    def getAggregation(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fieldnm: str):
+    def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
         """
         TODO: Infinity doesn't provide aggregation
         """