From 2044bb0039dcd9c3dccd0100ae06291db79bab42 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Wed, 20 Nov 2024 09:31:36 +0800 Subject: [PATCH] Fix bugs (#3502) ### What problem does this PR solve? 1. Remove unused code 2. Fix type mismatch, in nlp search and infinity search interface 3. Fix chunk list, get all chunks of this user. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Signed-off-by: jinhai --- agent/component/base.py | 3 +- graphrag/search.py | 2 +- rag/utils/infinity_conn.py | 95 +++++++++++++++++++++----------------- 3 files changed, 54 insertions(+), 46 deletions(-) diff --git a/agent/component/base.py b/agent/component/base.py index fc44b4da7..5624b7bff 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -17,13 +17,13 @@ from abc import ABC import builtins import json import os +import logging from functools import partial from typing import Tuple, Union import pandas as pd from agent import settings -from agent.settings import flow_logger, DEBUG _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params" _DEPRECATED_PARAMS = "_deprecated_params" @@ -480,7 +480,6 @@ class ComponentBase(ABC): upstream_outs = [] - if DEBUG: print(self.component_name, reversed_cpnts[::-1]) for u in reversed_cpnts[::-1]: if self.get_component_name(u) in ["switch", "concentrator"]: continue if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval": diff --git a/graphrag/search.py b/graphrag/search.py index 24735c4b5..3fa962fcb 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -23,7 +23,7 @@ from rag.nlp.search import Dealer class KGSearch(Dealer): - def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False): + def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False): def merge_into_first(sres, title="") -> dict[str, str]: if not sres: return {} diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index ad5a83ab7..4d7ef0dcd 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -4,7 +4,7 @@ import re import json import time import infinity -from infinity.common import ConflictType, InfinityException +from infinity.common import ConflictType, InfinityException, SortType from infinity.index import IndexInfo, IndexType from infinity.connection_pool import ConnectionPool from rag import settings @@ -22,6 +22,7 @@ from rag.utils.doc_store_conn import ( OrderByExpr, ) + def equivalent_condition_to_str(condition: dict) -> str: assert "_id" not in condition cond = list() @@ -65,7 +66,7 @@ class InfinityConnection(DocStoreConnection): self.connPool = connPool break except Exception as e: - logging.warn(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.") + logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.") time.sleep(5) if self.connPool is None: msg = f"Infinity {infinity_uri} didn't become healthy in 120s." @@ -168,7 +169,7 @@ class InfinityConnection(DocStoreConnection): self.connPool.release_conn(inf_conn) return True except Exception as e: - logging.warn(f"INFINITY indexExist {str(e)}") + logging.warning(f"INFINITY indexExist {str(e)}") return False """ @@ -176,16 +177,16 @@ class InfinityConnection(DocStoreConnection): """ def search( - self, - selectFields: list[str], - highlightFields: list[str], - condition: dict, - matchExprs: list[MatchExpr], - orderBy: OrderByExpr, - offset: int, - limit: int, - indexNames: str|list[str], - knowledgebaseIds: list[str], + self, + selectFields: list[str], + highlightFields: list[str], + condition: dict, + matchExprs: list[MatchExpr], + orderBy: OrderByExpr, + offset: int, + limit: int, + indexNames: str | list[str], + knowledgebaseIds: list[str], ) -> list[dict] | pl.DataFrame: """ TODO: Infinity doesn't provide highlight @@ -219,8 +220,8 @@ class InfinityConnection(DocStoreConnection): minimum_should_match = "0%" if "minimum_should_match" in matchExpr.extra_options: minimum_should_match = ( - str(int(matchExpr.extra_options["minimum_should_match"] * 100)) - + "%" + str(int(matchExpr.extra_options["minimum_should_match"] * 100)) + + "%" ) matchExpr.extra_options.update( {"minimum_should_match": minimum_should_match} @@ -234,10 +235,14 @@ class InfinityConnection(DocStoreConnection): for k, v in matchExpr.extra_options.items(): if not isinstance(v, str): matchExpr.extra_options[k] = str(v) + + order_by_expr_list = list() if orderBy.fields: - order_by_expr_list = list() for order_field in orderBy.fields: - order_by_expr_list.append((order_field[0], order_field[1] == 0)) + if order_field[1] == 0: + order_by_expr_list.append((order_field[0], SortType.Asc)) + else: + order_by_expr_list.append((order_field[0], SortType.Desc)) # Scatter search tables and gather the results for indexName in indexNames: @@ -249,28 +254,32 @@ class InfinityConnection(DocStoreConnection): continue table_list.append(table_name) builder = table_instance.output(selectFields) - for matchExpr in matchExprs: - if isinstance(matchExpr, MatchTextExpr): - fields = ",".join(matchExpr.fields) - builder = builder.match_text( - fields, - matchExpr.matching_text, - matchExpr.topn, - matchExpr.extra_options, - ) - elif isinstance(matchExpr, MatchDenseExpr): - builder = builder.match_dense( - matchExpr.vector_column_name, - matchExpr.embedding_data, - matchExpr.embedding_data_type, - matchExpr.distance_type, - matchExpr.topn, - matchExpr.extra_options, - ) - elif isinstance(matchExpr, FusionExpr): - builder = builder.fusion( - matchExpr.method, matchExpr.topn, matchExpr.fusion_params - ) + if len(matchExprs) > 0: + for matchExpr in matchExprs: + if isinstance(matchExpr, MatchTextExpr): + fields = ",".join(matchExpr.fields) + builder = builder.match_text( + fields, + matchExpr.matching_text, + matchExpr.topn, + matchExpr.extra_options, + ) + elif isinstance(matchExpr, MatchDenseExpr): + builder = builder.match_dense( + matchExpr.vector_column_name, + matchExpr.embedding_data, + matchExpr.embedding_data_type, + matchExpr.distance_type, + matchExpr.topn, + matchExpr.extra_options, + ) + elif isinstance(matchExpr, FusionExpr): + builder = builder.fusion( + matchExpr.method, matchExpr.topn, matchExpr.fusion_params + ) + else: + if len(filter_cond) > 0: + builder.filter(filter_cond) if orderBy.fields: builder.sort(order_by_expr_list) builder.offset(offset).limit(limit) @@ -282,7 +291,7 @@ class InfinityConnection(DocStoreConnection): return res def get( - self, chunkId: str, indexName: str, knowledgebaseIds: list[str] + self, chunkId: str, indexName: str, knowledgebaseIds: list[str] ) -> dict | None: inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) @@ -299,7 +308,7 @@ class InfinityConnection(DocStoreConnection): return res_fields.get(chunkId, None) def insert( - self, documents: list[dict], indexName: str, knowledgebaseId: str + self, documents: list[dict], indexName: str, knowledgebaseId: str ) -> list[str]: inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) @@ -341,7 +350,7 @@ class InfinityConnection(DocStoreConnection): return [] def update( - self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str + self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str ) -> bool: # if 'position_list' in newValue: # logging.info(f"upsert position_list: {newValue['position_list']}") @@ -430,7 +439,7 @@ class InfinityConnection(DocStoreConnection): flags=re.IGNORECASE | re.MULTILINE, ) if not re.search( - r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE + r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE ): continue txts.append(t)