From c354239b797bef96faa271589b32e574992483d9 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Sun, 26 Jan 2025 18:45:36 +0800 Subject: [PATCH] Make infinity adapt to condition `exist`. (#4657) ### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/kb_app.py | 10 +++++- rag/raptor.py | 4 ++- rag/utils/es_conn.py | 2 +- rag/utils/infinity_conn.py | 68 +++++++++++++++++++++++++++++++------- 4 files changed, 69 insertions(+), 15 deletions(-) diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 0cb511611..6dec0e2e6 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -24,6 +24,7 @@ from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.user_service import TenantService, UserTenantService +from api.settings import DOC_ENGINE from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters from api.utils import get_uuid from api.db import StatusEnum, FileSource @@ -96,6 +97,13 @@ def update(): return get_data_error_result( message="Can't find this knowledgebase!") + if req.get("parser_id", "") == "tag" and DOC_ENGINE == "infinity": + return get_json_result( + data=False, + message='The chunk method Tag has not been supported by Infinity yet.', + code=settings.RetCode.OPERATING_ERROR + ) + if req["name"].lower() != kb.name.lower() \ and len( KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1: @@ -112,7 +120,7 @@ def update(): search.index_name(kb.tenant_id), kb.id) else: # Elasticsearch requires PAGERANK_FLD be non-zero! - settings.docStoreConn.update({"exist": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) e, kb = KnowledgebaseService.get_by_id(kb.id) diff --git a/rag/raptor.py b/rag/raptor.py index e05293cee..fcc686565 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -71,7 +71,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: start, end = 0, len(chunks) if len(chunks) <= 1: return - chunks = [(s, a) for s, a in chunks if len(a) > 0] + chunks = [(s, a) for s, a in chunks if s and len(a) > 0] def summarize(ck_idx, lock): nonlocal chunks @@ -125,6 +125,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: threads = [] for c in range(n_clusters): ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] + if not ck_idx: + continue threads.append(executor.submit(summarize, ck_idx, lock)) wait(threads, return_when=ALL_COMPLETED) for th in threads: diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index ce05d5d6d..c4ed0f185 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -336,7 +336,7 @@ class ESConnection(DocStoreConnection): for k, v in condition.items(): if not isinstance(k, str) or not v: continue - if k == "exist": + if k == "exists": bqry.filter.append(Q("exists", field=v)) continue if isinstance(v, list): diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index b6d3202aa..e8a38c177 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -44,8 +44,23 @@ from rag.utils.doc_store_conn import ( logger = logging.getLogger('ragflow.infinity_conn') -def equivalent_condition_to_str(condition: dict) -> str | None: +def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None: assert "_id" not in condition + clmns = {} + if table_instance: + for n, ty, de, _ in table_instance.show_columns().rows(): + clmns[n] = (ty, de) + + def exists(cln): + nonlocal clmns + assert cln in clmns, f"'{cln}' should be in '{clmns}'." + ty, de = clmns[cln] + if ty.lower().find("cha"): + if not de: + de = "" + return f" {cln}!='{de}' " + return f"{cln}!={de}" + cond = list() for k, v in condition.items(): if not isinstance(k, str) or k in ["kb_id"] or not v: @@ -61,8 +76,15 @@ def equivalent_condition_to_str(condition: dict) -> str | None: strInCond = ", ".join(inCond) strInCond = f"{k} IN ({strInCond})" cond.append(strInCond) + elif k == "must_not": + if isinstance(v, dict): + for kk, vv in v.items(): + if kk == "exists": + cond.append("NOT (%s)" % exists(vv)) elif isinstance(v, str): cond.append(f"{k}='{v}'") + elif k == "exists": + cond.append(exists(v)) else: cond.append(f"{k}={str(v)}") return " AND ".join(cond) if cond else "1=1" @@ -294,7 +316,11 @@ class InfinityConnection(DocStoreConnection): filter_cond = None filter_fulltext = "" if condition: - filter_cond = equivalent_condition_to_str(condition) + for indexName in indexNames: + table_name = f"{indexName}_{knowledgebaseIds[0]}" + filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name)) + break + for matchExpr in matchExprs: if isinstance(matchExpr, MatchTextExpr): if filter_cond and "filter" not in matchExpr.extra_options: @@ -434,12 +460,21 @@ class InfinityConnection(DocStoreConnection): self.createIdx(indexName, knowledgebaseId, vector_size) table_instance = db_instance.get_table(table_name) + # embedding fields can't have a default value.... + embedding_clmns = [] + clmns = table_instance.show_columns().rows() + for n, ty, _, _ in clmns: + r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty) + if not r: + continue + embedding_clmns.append((n, int(r.group(1)))) + docs = copy.deepcopy(documents) for d in docs: assert "_id" not in d assert "id" in d for k, v in d.items(): - if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd"]: + if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: assert isinstance(v, list) d[k] = "###".join(v) elif re.search(r"_feas$", k): @@ -454,6 +489,11 @@ class InfinityConnection(DocStoreConnection): elif k in ["page_num_int", "top_int"]: assert isinstance(v, list) d[k] = "_".join(f"{num:08x}" for num in v) + + for n, vs in embedding_clmns: + if n in d: + continue + d[n] = [0] * vs ids = ["'{}'".format(d["id"]) for d in docs] str_ids = ", ".join(ids) str_filter = f"id IN ({str_ids})" @@ -475,11 +515,11 @@ class InfinityConnection(DocStoreConnection): db_instance = inf_conn.get_database(self.dbName) table_name = f"{indexName}_{knowledgebaseId}" table_instance = db_instance.get_table(table_name) - if "exist" in condition: - del condition["exist"] - filter = equivalent_condition_to_str(condition) + #if "exists" in condition: + # del condition["exists"] + filter = equivalent_condition_to_str(condition, table_instance) for k, v in list(newValue.items()): - if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd"]: + if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: assert isinstance(v, list) newValue[k] = "###".join(v) elif re.search(r"_feas$", k): @@ -496,9 +536,11 @@ class InfinityConnection(DocStoreConnection): elif k in ["page_num_int", "top_int"]: assert isinstance(v, list) newValue[k] = "_".join(f"{num:08x}" for num in v) - elif k == "remove" and v in [PAGERANK_FLD]: + elif k == "remove": del newValue[k] - newValue[v] = 0 + if v in [PAGERANK_FLD]: + newValue[v] = 0 + logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.") table_instance.update(filter, newValue) self.connPool.release_conn(inf_conn) @@ -508,14 +550,14 @@ class InfinityConnection(DocStoreConnection): inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) table_name = f"{indexName}_{knowledgebaseId}" - filter = equivalent_condition_to_str(condition) try: table_instance = db_instance.get_table(table_name) except Exception: logger.warning( - f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist." + f"Skipped deleting from table {table_name} since the table doesn't exist." ) return 0 + filter = equivalent_condition_to_str(condition, table_instance) logger.debug(f"INFINITY delete table {table_name}, filter {filter}.") res = table_instance.delete(filter) self.connPool.release_conn(inf_conn) @@ -553,7 +595,7 @@ class InfinityConnection(DocStoreConnection): v = res[fieldnm][i] if isinstance(v, Series): v = list(v) - elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd"]: + elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: assert isinstance(v, str) v = [kwd for kwd in v.split("###") if kwd] elif fieldnm == "position_int": @@ -584,6 +626,8 @@ class InfinityConnection(DocStoreConnection): ans = {} num_rows = len(res) column_id = res["id"] + if fieldnm not in res: + return {} for i in range(num_rows): id = column_id[i] txt = res[fieldnm][i]