Use to_df instead of to_pl when fetching Infinity results (#5604)
### What problem does this PR solve?

Read Infinity query results as pandas DataFrames (`to_df`) instead of polars (`to_pl`), and post-process the result columns with vectorized pandas operations.

### Type of change

- [x] Performance Improvement

---------

Co-authored-by: wangwei <dwxiayi@163.com>
parent 555c70672e
commit 76e8285904
--- a/conf/infinity_mapping.json
+++ b/conf/infinity_mapping.json
@@ -9,10 +9,10 @@
   "title_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-  "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-  "tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
+  "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
+  "tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
   "important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-  "question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
+  "question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
   "question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "content_with_weight": {"type": "varchar", "default": ""},
   "content_ltks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
@@ -28,7 +28,7 @@
   "rank_flt": {"type": "float", "default": 0},
   "available_int": {"type": "integer", "default": 1},
   "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-  "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
+  "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
   "pagerank_fea": {"type": "integer", "default": 0},
   "tag_feas": {"type": "varchar", "default": ""},
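The `important_kwd`, `tag_kwd`, `question_kwd`, and `entities_kwd` columns switch from the `whitespace` analyzer to `whitespace-#`, which presumably also breaks tokens on `#`. RAGFlow stores multi-valued keyword fields as one `###`-joined varchar (the `getFields` change further down splits on `"###"`). A minimal sketch of that round trip, with made-up sample values:

```python
# Sketch: how multi-valued *_kwd fields round-trip through Infinity varchar columns.
# Sample values are hypothetical; the "###" separator matches the split in getFields.

def pack_kwd(keywords: list[str]) -> str:
    """Join keywords into the single varchar value stored in Infinity."""
    return "###".join(keywords)

def unpack_kwd(value: str) -> list[str]:
    """Invert pack_kwd; empty segments are dropped, as getFields does."""
    return [kwd for kwd in value.split("###") if kwd]

stored = pack_kwd(["machine learning", "rag", "retrieval"])
assert unpack_kwd(stored) == ["machine learning", "rag", "retrieval"]
```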
--- a/rag/utils/infinity_conn.py
+++ b/rag/utils/infinity_conn.py
@@ -28,8 +28,7 @@ from infinity.errors import ErrorCode
 from rag import settings
 from rag.settings import PAGERANK_FLD
 from rag.utils import singleton
-import polars as pl
-from polars.series.series import Series
+import pandas as pd
 from api.utils.file_utils import get_project_base_directory
 
 from rag.utils.doc_store_conn import (
@@ -90,20 +89,20 @@ def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
     return " AND ".join(cond) if cond else "1=1"
 
 
-def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
-    """
-    Concatenate multiple dataframes into one.
-    """
-    df_list = [df for df in df_list if not df.is_empty()]
-    if df_list:
-        return pl.concat(df_list)
-    schema = dict()
+def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
+    df_list2 = [df for df in df_list if not df.empty]
+    if df_list2:
+        return pd.concat(df_list2, axis=0).reset_index(drop=True)
+
+    schema = []
     for field_name in selectFields:
         if field_name == 'score()':  # Workaround: fix schema is changed to score()
-            schema['SCORE'] = str
+            schema.append('SCORE')
+        elif field_name == 'similarity()':  # Workaround: fix schema is changed to similarity()
+            schema.append('SIMILARITY')
         else:
-            schema[field_name] = str
-    return pl.DataFrame(schema=schema)
+            schema.append(field_name)
+    return pd.DataFrame(columns=schema)
 
 
 @singleton
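With polars gone, the empty-result path changes: instead of building a typed polars schema, `concat_dataframes` now returns an empty pandas frame that carries only the expected column names, mapping the pseudo-fields `score()` and `similarity()` to the `SCORE`/`SIMILARITY` columns Infinity reports. A condensed sketch of the new behavior (names mirror the diff):

```python
import pandas as pd

def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
    # Stack the per-table results row-wise; reset_index avoids duplicated index labels.
    non_empty = [df for df in df_list if not df.empty]
    if non_empty:
        return pd.concat(non_empty, axis=0).reset_index(drop=True)
    # No rows anywhere: return an empty frame that still has the expected columns.
    schema = ["SCORE" if f == "score()" else "SIMILARITY" if f == "similarity()" else f
              for f in selectFields]
    return pd.DataFrame(columns=schema)

empty = concat_dataframes([pd.DataFrame()], ["id", "score()"])
assert list(empty.columns) == ["id", "SCORE"] and empty.empty
```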
@@ -121,7 +120,7 @@ class InfinityConnection(DocStoreConnection):
             connPool = ConnectionPool(infinity_uri)
             inf_conn = connPool.get_conn()
             res = inf_conn.show_current_node()
-            if res.error_code == ErrorCode.OK and res.server_status == "started":
+            if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
                 self._migrate_db(inf_conn)
                 self.connPool = connPool
                 connPool.release_conn(inf_conn)
@@ -189,7 +188,7 @@ class InfinityConnection(DocStoreConnection):
         self.connPool.release_conn(inf_conn)
         res2 = {
             "type": "infinity",
-            "status": "green" if res.error_code == 0 and res.server_status == "started" else "red",
+            "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
             "error": res.error_msg,
         }
         return res2
@@ -281,7 +280,7 @@ class InfinityConnection(DocStoreConnection):
         knowledgebaseIds: list[str],
         aggFields: list[str] = [],
         rank_feature: dict | None = None
-    ) -> list[dict] | pl.DataFrame:
+    ) -> tuple[pd.DataFrame, int]:
         """
         TODO: Infinity doesn't provide highlight
         """
@@ -292,9 +291,10 @@ class InfinityConnection(DocStoreConnection):
         db_instance = inf_conn.get_database(self.dbName)
         df_list = list()
         table_list = list()
+        output = selectFields.copy()
         for essential_field in ["id"]:
-            if essential_field not in selectFields:
-                selectFields.append(essential_field)
+            if essential_field not in output:
+                output.append(essential_field)
         score_func = ""
         score_column = ""
         for matchExpr in matchExprs:
@@ -309,9 +309,11 @@ class InfinityConnection(DocStoreConnection):
                 score_column = "SIMILARITY"
                 break
         if matchExprs:
-            selectFields.append(score_func)
-            selectFields.append(PAGERANK_FLD)
-            selectFields = [f for f in selectFields if f != "_score"]
+            if score_func not in output:
+                output.append(score_func)
+            if PAGERANK_FLD not in output:
+                output.append(PAGERANK_FLD)
+            output = [f for f in output if f != "_score"]
 
         # Prepare expressions common to all tables
         filter_cond = None
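These two hunks also stop appending `score()` and the pagerank field to `selectFields` itself and collect the projection in a separate `output` list. That keeps the caller's `selectFields` unmutated, and the `not in` guards prevent duplicate columns on repeated searches. A small sketch of the aliasing hazard being avoided (hypothetical helper names):

```python
# Hypothetical sketch: why appending to the caller's list is a hazard.
def search_mutating(selectFields: list[str]) -> list[str]:
    selectFields.append("score()")   # leaks into the caller's list on every call
    return selectFields

def search_copying(selectFields: list[str]) -> list[str]:
    output = selectFields.copy()     # private working copy, as in the diff
    if "score()" not in output:
        output.append("score()")
    return output

fields = ["id", "content_ltks"]
search_mutating(fields)
assert fields == ["id", "content_ltks", "score()"]  # caller's list changed

fields = ["id", "content_ltks"]
search_copying(fields)
assert fields == ["id", "content_ltks"]             # caller's list intact
```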
@@ -339,7 +341,7 @@ class InfinityConnection(DocStoreConnection):
                     matchExpr.extra_options[k] = str(v)
                 logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
             elif isinstance(matchExpr, MatchDenseExpr):
-                if filter_fulltext and filter_cond and "filter" not in matchExpr.extra_options:
+                if filter_fulltext and "filter" not in matchExpr.extra_options:
                     matchExpr.extra_options.update({"filter": filter_fulltext})
                 for k, v in matchExpr.extra_options.items():
                     if not isinstance(v, str):
@@ -370,7 +372,7 @@ class InfinityConnection(DocStoreConnection):
             except Exception:
                 continue
             table_list.append(table_name)
-            builder = table_instance.output(selectFields)
+            builder = table_instance.output(output)
             if len(matchExprs) > 0:
                 for matchExpr in matchExprs:
                     if isinstance(matchExpr, MatchTextExpr):
@@ -379,7 +381,7 @@ class InfinityConnection(DocStoreConnection):
                             fields,
                             matchExpr.matching_text,
                             matchExpr.topn,
-                            matchExpr.extra_options,
+                            matchExpr.extra_options.copy(),
                         )
                     elif isinstance(matchExpr, MatchDenseExpr):
                         builder = builder.match_dense(
@@ -388,7 +390,7 @@ class InfinityConnection(DocStoreConnection):
                             matchExpr.embedding_data_type,
                             matchExpr.distance_type,
                             matchExpr.topn,
-                            matchExpr.extra_options,
+                            matchExpr.extra_options.copy(),
                         )
                     elif isinstance(matchExpr, FusionExpr):
                         builder = builder.fusion(
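`match_text` and `match_dense` now receive `matchExpr.extra_options.copy()` rather than the shared dict. Since the same `matchExpr` objects are reused for every table in the loop, passing a copy ensures that anything a per-table call does to the options cannot leak into the next table's query. A minimal illustration of the aliasing problem, with a hypothetical consumer that mutates its argument:

```python
# Hypothetical sketch: a consumer that mutates the options dict it is given.
shared_options = {"topn": 10}

def consume(options: dict) -> None:
    options["topn"] = str(options["topn"])  # stringifies the value in place

consume(shared_options)   # first table: 10 -> "10"
consume(shared_options)   # second table already sees the stringified value

# Passing shared_options.copy() per call keeps each table's view pristine.
```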
@@ -400,18 +402,17 @@ class InfinityConnection(DocStoreConnection):
             if orderBy.fields:
                 builder.sort(order_by_expr_list)
             builder.offset(offset).limit(limit)
-            kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
+            kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
             if extra_result:
                 total_hits_count += int(extra_result["total_hits_count"])
             logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
             df_list.append(kb_res)
         self.connPool.release_conn(inf_conn)
-        res = concat_dataframes(df_list, selectFields)
+        res = concat_dataframes(df_list, output)
         if matchExprs:
-            res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
-            if score_column and score_column != "SCORE":
-                res = res.rename({score_column: "_score"})
-            res = res.limit(limit)
+            res['Sum'] = res[score_column] + res[PAGERANK_FLD]
+            res = res.sort_values(by='Sum', ascending=False).reset_index(drop=True).drop(columns=['Sum'])
+            res = res.head(limit)
         logger.debug(f"INFINITY search final result: {str(res)}")
         return res, total_hits_count
 
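The polars ranking `res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True)` becomes a temporary `Sum` column in pandas, sorted descending and dropped again. A runnable sketch of the same ranking with made-up scores (`pagerank_fea` stands in for `PAGERANK_FLD`, per the mapping above):

```python
import pandas as pd

res = pd.DataFrame({
    "id": ["a", "b", "c"],
    "SCORE": [0.50, 0.90, 0.70],        # hypothetical engine scores
    "pagerank_fea": [0.30, 0.00, 0.00], # hypothetical pagerank boosts
})
res["Sum"] = res["SCORE"] + res["pagerank_fea"]
res = res.sort_values(by="Sum", ascending=False).reset_index(drop=True).drop(columns=["Sum"])
assert res["id"].tolist() == ["b", "a", "c"]  # pagerank lifts 'a' above 'c'
```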
@@ -433,12 +434,12 @@ class InfinityConnection(DocStoreConnection):
                 logger.warning(
                     f"Table not found: {table_name}, this knowledge base isn't created in Infinity. Maybe it is created in other document engine.")
                 continue
-            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
+            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
             logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
             df_list.append(kb_res)
         self.connPool.release_conn(inf_conn)
         res = concat_dataframes(df_list, ["id"])
-        res_fields = self.getFields(res, res.columns)
+        res_fields = self.getFields(res, res.columns.tolist())
         return res_fields.get(chunkId, None)
 
     def insert(
@@ -572,60 +573,54 @@ class InfinityConnection(DocStoreConnection):
     Helper functions for search result
     """
 
-    def getTotal(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> int:
+    def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
         if isinstance(res, tuple):
             return res[1]
         return len(res)
 
-    def getChunkIds(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> list[str]:
+    def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
         if isinstance(res, tuple):
             res = res[0]
         return list(res["id"])
 
-    def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[str]) -> list[str, dict]:
+    def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
         if isinstance(res, tuple):
             res = res[0]
-        res_fields = {}
         if not fields:
             return {}
-        num_rows = len(res)
-        column_id = res["id"]
-        for i in range(num_rows):
-            id = column_id[i]
-            m = {"id": id}
-            for fieldnm in fields:
-                if fieldnm not in res:
-                    m[fieldnm] = None
-                    continue
-                v = res[fieldnm][i]
-                if isinstance(v, Series):
-                    v = list(v)
-                elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
-                    assert isinstance(v, str)
-                    v = [kwd for kwd in v.split("###") if kwd]
-                elif fieldnm == "position_int":
-                    assert isinstance(v, str)
-                    if v:
-                        arr = [int(hex_val, 16) for hex_val in v.split('_')]
-                        v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
-                    else:
-                        v = []
-                elif fieldnm in ["page_num_int", "top_int"]:
-                    assert isinstance(v, str)
-                    if v:
-                        v = [int(hex_val, 16) for hex_val in v.split('_')]
-                    else:
-                        v = []
-                else:
-                    if not isinstance(v, str):
-                        v = str(v)
-                    # if fieldnm.endswith("_tks"):
-                    #     v = rmSpace(v)
-                m[fieldnm] = v
-            res_fields[id] = m
-        return res_fields
+        fieldsAll = fields.copy()
+        fieldsAll.append('id')
+        column_map = {col.lower(): col for col in res.columns}
+        matched_columns = {column_map[col.lower()]: col for col in set(fieldsAll) if col.lower() in column_map}
+        none_columns = [col for col in set(fieldsAll) if col.lower() not in column_map]
+
+        res2 = res[matched_columns.keys()]
+        res2 = res2.rename(columns=matched_columns)
+        res2.drop_duplicates(subset=['id'], inplace=True)
+
+        for column in res2.columns:
+            k = column.lower()
+            if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
+                res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
+            elif k == "position_int":
+                def to_position_int(v):
+                    if v:
+                        arr = [int(hex_val, 16) for hex_val in v.split('_')]
+                        v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
+                    else:
+                        v = []
+                    return v
+                res2[column] = res2[column].apply(to_position_int)
+            elif k in ["page_num_int", "top_int"]:
+                res2[column] = res2[column].apply(lambda v: [int(hex_val, 16) for hex_val in v.split('_')] if v else [])
+            else:
+                pass
+        for column in none_columns:
+            res2[column] = None
+
+        return res2.set_index("id").to_dict(orient="index")
 
-    def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: list[str], fieldnm: str):
+    def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
         if isinstance(res, tuple):
             res = res[0]
         ans = {}
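The rewritten `getFields` trades the per-row, per-field Python loop for column-wise `apply` calls plus one final `set_index("id").to_dict(orient="index")`, which yields the same `{id: {field: value}}` mapping in a single step (the id now lives in the index, so the inner dicts no longer repeat it). The least obvious piece is the `position_int` decoding, reconstructed here as a standalone sketch; the five-int grouping follows the diff, and the field meanings are assumed:

```python
# Sketch of the position_int decoding used in getFields. Positions are stored as
# one '_'-joined string of hex ints, grouped in fives (assumed: page, x0, x1, top, bottom).
def decode_position_int(v: str) -> list[list[int]]:
    if not v:
        return []
    arr = [int(hex_val, 16) for hex_val in v.split('_')]
    return [arr[i:i + 5] for i in range(0, len(arr), 5)]

assert decode_position_int("1_a_14_1e_28_2_b_15_1f_29") == [
    [1, 10, 20, 30, 40],
    [2, 11, 21, 31, 41],
]
```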
@@ -655,7 +650,7 @@ class InfinityConnection(DocStoreConnection):
             ans[id] = "...".join(txts)
         return ans
 
-    def getAggregation(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fieldnm: str):
+    def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
         """
         TODO: Infinity doesn't provide aggregation
         """