From 76e82859040907b9c71155678f72931fb1d4ebbc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B1=AA=E5=A8=81?=
Date: Wed, 5 Mar 2025 09:35:40 +0800
Subject: [PATCH] use to_df instead of to_pl when getting Infinity results
 (#5604)

### What problem does this PR solve?

Fetch Infinity query results as pandas DataFrames via `to_df()` instead of polars DataFrames via `to_pl()`, and rewrite the downstream result handling (`concat_dataframes`, score sorting, `getFields`) against the pandas API. The PR also switches the keyword list fields (`important_kwd`, `tag_kwd`, `question_kwd`, `entities_kwd`) to the `whitespace-#` analyzer and lets the health checks accept the `"alive"` server status alongside `"started"`. A sketch of the new result path is shown below.
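As a reviewer aid (not part of the patch), here is a minimal standalone sketch of the new pandas result path; the two frames stand in for what `builder.to_df()` returns per table, and the column names `SCORE` and `pagerank_fea` follow the diff below:

```python
import pandas as pd

# Stand-ins for the per-table frames that builder.to_df() would return.
df_a = pd.DataFrame({"id": ["c1", "c2"], "SCORE": [0.9, 0.4], "pagerank_fea": [0, 3]})
df_b = pd.DataFrame({"id": ["c3"], "SCORE": [0.7], "pagerank_fea": [1]})

# concat_dataframes(): keep non-empty frames, concatenate, reset the index.
frames = [df for df in (df_a, df_b) if not df.empty]
res = pd.concat(frames, axis=0).reset_index(drop=True)

# search(): rank by match score plus pagerank, then truncate to the limit.
res["Sum"] = res["SCORE"] + res["pagerank_fea"]
res = res.sort_values(by="Sum", ascending=False).reset_index(drop=True).drop(columns=["Sum"])
print(res.head(2))  # analogous to res.head(limit)
```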
- """ - df_list = [df for df in df_list if not df.is_empty()] - if df_list: - return pl.concat(df_list) - schema = dict() +def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame: + df_list2 = [df for df in df_list if not df.empty] + if df_list2: + return pd.concat(df_list2, axis=0).reset_index(drop=True) + + schema = [] for field_name in selectFields: - if field_name == 'score()': # Workaround: fix schema is changed to score() - schema['SCORE'] = str + if field_name == 'score()': # Workaround: fix schema is changed to score() + schema.append('SCORE') + elif field_name == 'similarity()': # Workaround: fix schema is changed to similarity() + schema.append('SIMILARITY') else: - schema[field_name] = str - return pl.DataFrame(schema=schema) + schema.append(field_name) + return pd.DataFrame(columns=schema) @singleton @@ -121,7 +120,7 @@ class InfinityConnection(DocStoreConnection): connPool = ConnectionPool(infinity_uri) inf_conn = connPool.get_conn() res = inf_conn.show_current_node() - if res.error_code == ErrorCode.OK and res.server_status == "started": + if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]: self._migrate_db(inf_conn) self.connPool = connPool connPool.release_conn(inf_conn) @@ -189,7 +188,7 @@ class InfinityConnection(DocStoreConnection): self.connPool.release_conn(inf_conn) res2 = { "type": "infinity", - "status": "green" if res.error_code == 0 and res.server_status == "started" else "red", + "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red", "error": res.error_msg, } return res2 @@ -281,7 +280,7 @@ class InfinityConnection(DocStoreConnection): knowledgebaseIds: list[str], aggFields: list[str] = [], rank_feature: dict | None = None - ) -> list[dict] | pl.DataFrame: + ) -> tuple[pd.DataFrame, int]: """ TODO: Infinity doesn't provide highlight """ @@ -292,9 +291,10 @@ class InfinityConnection(DocStoreConnection): db_instance = inf_conn.get_database(self.dbName) df_list = list() table_list = list() + output = selectFields.copy() for essential_field in ["id"]: - if essential_field not in selectFields: - selectFields.append(essential_field) + if essential_field not in output: + output.append(essential_field) score_func = "" score_column = "" for matchExpr in matchExprs: @@ -309,9 +309,11 @@ class InfinityConnection(DocStoreConnection): score_column = "SIMILARITY" break if matchExprs: - selectFields.append(score_func) - selectFields.append(PAGERANK_FLD) - selectFields = [f for f in selectFields if f != "_score"] + if score_func not in output: + output.append(score_func) + if PAGERANK_FLD not in output: + output.append(PAGERANK_FLD) + output = [f for f in output if f != "_score"] # Prepare expressions common to all tables filter_cond = None @@ -339,7 +341,7 @@ class InfinityConnection(DocStoreConnection): matchExpr.extra_options[k] = str(v) logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}") elif isinstance(matchExpr, MatchDenseExpr): - if filter_fulltext and filter_cond and "filter" not in matchExpr.extra_options: + if filter_fulltext and "filter" not in matchExpr.extra_options: matchExpr.extra_options.update({"filter": filter_fulltext}) for k, v in matchExpr.extra_options.items(): if not isinstance(v, str): @@ -370,7 +372,7 @@ class InfinityConnection(DocStoreConnection): except Exception: continue table_list.append(table_name) - builder = table_instance.output(selectFields) + builder = table_instance.output(output) if len(matchExprs) > 0: 
for matchExpr in matchExprs: if isinstance(matchExpr, MatchTextExpr): @@ -379,7 +381,7 @@ class InfinityConnection(DocStoreConnection): fields, matchExpr.matching_text, matchExpr.topn, - matchExpr.extra_options, + matchExpr.extra_options.copy(), ) elif isinstance(matchExpr, MatchDenseExpr): builder = builder.match_dense( @@ -388,7 +390,7 @@ class InfinityConnection(DocStoreConnection): matchExpr.embedding_data_type, matchExpr.distance_type, matchExpr.topn, - matchExpr.extra_options, + matchExpr.extra_options.copy(), ) elif isinstance(matchExpr, FusionExpr): builder = builder.fusion( @@ -400,18 +402,17 @@ class InfinityConnection(DocStoreConnection): if orderBy.fields: builder.sort(order_by_expr_list) builder.offset(offset).limit(limit) - kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl() + kb_res, extra_result = builder.option({"total_hits_count": True}).to_df() if extra_result: total_hits_count += int(extra_result["total_hits_count"]) logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}") df_list.append(kb_res) self.connPool.release_conn(inf_conn) - res = concat_dataframes(df_list, selectFields) + res = concat_dataframes(df_list, output) if matchExprs: - res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True) - if score_column and score_column != "SCORE": - res = res.rename({score_column: "_score"}) - res = res.limit(limit) + res['Sum'] = res[score_column] + res[PAGERANK_FLD] + res = res.sort_values(by='Sum', ascending=False).reset_index(drop=True).drop(columns=['Sum']) + res = res.head(limit) logger.debug(f"INFINITY search final result: {str(res)}") return res, total_hits_count @@ -433,12 +434,12 @@ class InfinityConnection(DocStoreConnection): logger.warning( f"Table not found: {table_name}, this knowledge base isn't created in Infinity. 
Maybe it is created in other document engine.") continue - kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl() + kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df() logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}") df_list.append(kb_res) self.connPool.release_conn(inf_conn) res = concat_dataframes(df_list, ["id"]) - res_fields = self.getFields(res, res.columns) + res_fields = self.getFields(res, res.columns.tolist()) return res_fields.get(chunkId, None) def insert( @@ -572,60 +573,54 @@ class InfinityConnection(DocStoreConnection): Helper functions for search result """ - def getTotal(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> int: + def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int: if isinstance(res, tuple): return res[1] return len(res) - def getChunkIds(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> list[str]: + def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]: if isinstance(res, tuple): res = res[0] return list(res["id"]) - def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[str]) -> list[str, dict]: + def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]: if isinstance(res, tuple): res = res[0] - res_fields = {} if not fields: return {} - num_rows = len(res) - column_id = res["id"] - for i in range(num_rows): - id = column_id[i] - m = {"id": id} - for fieldnm in fields: - if fieldnm not in res: - m[fieldnm] = None - continue - v = res[fieldnm][i] - if isinstance(v, Series): - v = list(v) - elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: - assert isinstance(v, str) - v = [kwd for kwd in v.split("###") if kwd] - elif fieldnm == "position_int": - assert isinstance(v, str) + fieldsAll = fields.copy() + fieldsAll.append('id') + column_map = {col.lower(): col for col in res.columns} + matched_columns = {column_map[col.lower()]:col for col in set(fieldsAll) if col.lower() in column_map} + none_columns = [col for col in set(fieldsAll) if col.lower() not in column_map] + + res2 = res[matched_columns.keys()] + res2 = res2.rename(columns=matched_columns) + res2.drop_duplicates(subset=['id'], inplace=True) + + for column in res2.columns: + k = column.lower() + if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: + res2[column] = res2[column].apply(lambda v:[kwd for kwd in v.split("###") if kwd]) + elif k == "position_int": + def to_position_int(v): if v: arr = [int(hex_val, 16) for hex_val in v.split('_')] v = [arr[i:i + 5] for i in range(0, len(arr), 5)] else: v = [] - elif fieldnm in ["page_num_int", "top_int"]: - assert isinstance(v, str) - if v: - v = [int(hex_val, 16) for hex_val in v.split('_')] - else: - v = [] - else: - if not isinstance(v, str): - v = str(v) - # if fieldnm.endswith("_tks"): - # v = rmSpace(v) - m[fieldnm] = v - res_fields[id] = m - return res_fields + return v + res2[column] = res2[column].apply(to_position_int) + elif k in ["page_num_int", "top_int"]: + res2[column] = res2[column].apply(lambda v:[int(hex_val, 16) for hex_val in v.split('_')] if v else []) + else: + pass + for column in none_columns: + res2[column] = None + + return res2.set_index("id").to_dict(orient="index") - def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: list[str], fieldnm: str): + def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], 
fieldnm: str): if isinstance(res, tuple): res = res[0] ans = {} @@ -655,7 +650,7 @@ class InfinityConnection(DocStoreConnection): ans[id] = "...".join(txts) return ans - def getAggregation(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fieldnm: str): + def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str): """ TODO: Infinity doesn't provide aggregation """
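Reviewer note (not part of the patch): a standalone sketch of how the rewritten `getFields` decodes the packed varchar columns — keyword fields are split on `###` (which presumably pairs with the `whitespace-#` analyzer switch in `infinity_mapping.json`), and `position_int` unpacks `_`-joined hex integers into groups of five. The sample values below are made up:

```python
import pandas as pd

df = pd.DataFrame({
    "id": ["c1"],
    "important_kwd": ["foo###bar"],           # keywords joined by "###"
    "position_int": ["1_2_3_4_5_a_b_c_d_e"],  # hex ints joined by "_", five per position box
})

# Keyword columns become lists, skipping empty fragments.
df["important_kwd"] = df["important_kwd"].apply(lambda v: [kwd for kwd in v.split("###") if kwd])

# position_int becomes a list of 5-int groups decoded from hex.
def to_position_int(v):
    if not v:
        return []
    arr = [int(hex_val, 16) for hex_val in v.split("_")]
    return [arr[i:i + 5] for i in range(0, len(arr), 5)]

df["position_int"] = df["position_int"].apply(to_position_int)

print(df.set_index("id").to_dict(orient="index"))
# {'c1': {'important_kwd': ['foo', 'bar'], 'position_int': [[1, 2, 3, 4, 5], [10, 11, 12, 13, 14]]}}
```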