diff --git a/agent/component/base.py b/agent/component/base.py index fc44b4da7..5624b7bff 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -17,13 +17,13 @@ from abc import ABC import builtins import json import os +import logging from functools import partial from typing import Tuple, Union import pandas as pd from agent import settings -from agent.settings import flow_logger, DEBUG _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params" _DEPRECATED_PARAMS = "_deprecated_params" @@ -480,7 +480,6 @@ class ComponentBase(ABC): upstream_outs = [] - if DEBUG: print(self.component_name, reversed_cpnts[::-1]) for u in reversed_cpnts[::-1]: if self.get_component_name(u) in ["switch", "concentrator"]: continue if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval": diff --git a/graphrag/search.py b/graphrag/search.py index 24735c4b5..3fa962fcb 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -23,7 +23,7 @@ from rag.nlp.search import Dealer class KGSearch(Dealer): - def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False): + def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False): def merge_into_first(sres, title="") -> dict[str, str]: if not sres: return {} diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index ad5a83ab7..4d7ef0dcd 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -4,7 +4,7 @@ import re import json import time import infinity -from infinity.common import ConflictType, InfinityException +from infinity.common import ConflictType, InfinityException, SortType from infinity.index import IndexInfo, IndexType from infinity.connection_pool import ConnectionPool from rag import settings @@ -22,6 +22,7 @@ from rag.utils.doc_store_conn import ( OrderByExpr, ) + def equivalent_condition_to_str(condition: dict) -> str: assert "_id" not in condition cond = list() @@ -65,7 +66,7 @@ class InfinityConnection(DocStoreConnection): self.connPool = connPool break except Exception as e: - logging.warn(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.") + logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.") time.sleep(5) if self.connPool is None: msg = f"Infinity {infinity_uri} didn't become healthy in 120s." @@ -168,7 +169,7 @@ class InfinityConnection(DocStoreConnection): self.connPool.release_conn(inf_conn) return True except Exception as e: - logging.warn(f"INFINITY indexExist {str(e)}") + logging.warning(f"INFINITY indexExist {str(e)}") return False """ @@ -176,16 +177,16 @@ class InfinityConnection(DocStoreConnection): """ def search( - self, - selectFields: list[str], - highlightFields: list[str], - condition: dict, - matchExprs: list[MatchExpr], - orderBy: OrderByExpr, - offset: int, - limit: int, - indexNames: str|list[str], - knowledgebaseIds: list[str], + self, + selectFields: list[str], + highlightFields: list[str], + condition: dict, + matchExprs: list[MatchExpr], + orderBy: OrderByExpr, + offset: int, + limit: int, + indexNames: str | list[str], + knowledgebaseIds: list[str], ) -> list[dict] | pl.DataFrame: """ TODO: Infinity doesn't provide highlight @@ -219,8 +220,8 @@ class InfinityConnection(DocStoreConnection): minimum_should_match = "0%" if "minimum_should_match" in matchExpr.extra_options: minimum_should_match = ( - str(int(matchExpr.extra_options["minimum_should_match"] * 100)) - + "%" + str(int(matchExpr.extra_options["minimum_should_match"] * 100)) + + "%" ) matchExpr.extra_options.update( {"minimum_should_match": minimum_should_match} @@ -234,10 +235,14 @@ class InfinityConnection(DocStoreConnection): for k, v in matchExpr.extra_options.items(): if not isinstance(v, str): matchExpr.extra_options[k] = str(v) + + order_by_expr_list = list() if orderBy.fields: - order_by_expr_list = list() for order_field in orderBy.fields: - order_by_expr_list.append((order_field[0], order_field[1] == 0)) + if order_field[1] == 0: + order_by_expr_list.append((order_field[0], SortType.Asc)) + else: + order_by_expr_list.append((order_field[0], SortType.Desc)) # Scatter search tables and gather the results for indexName in indexNames: @@ -249,28 +254,32 @@ class InfinityConnection(DocStoreConnection): continue table_list.append(table_name) builder = table_instance.output(selectFields) - for matchExpr in matchExprs: - if isinstance(matchExpr, MatchTextExpr): - fields = ",".join(matchExpr.fields) - builder = builder.match_text( - fields, - matchExpr.matching_text, - matchExpr.topn, - matchExpr.extra_options, - ) - elif isinstance(matchExpr, MatchDenseExpr): - builder = builder.match_dense( - matchExpr.vector_column_name, - matchExpr.embedding_data, - matchExpr.embedding_data_type, - matchExpr.distance_type, - matchExpr.topn, - matchExpr.extra_options, - ) - elif isinstance(matchExpr, FusionExpr): - builder = builder.fusion( - matchExpr.method, matchExpr.topn, matchExpr.fusion_params - ) + if len(matchExprs) > 0: + for matchExpr in matchExprs: + if isinstance(matchExpr, MatchTextExpr): + fields = ",".join(matchExpr.fields) + builder = builder.match_text( + fields, + matchExpr.matching_text, + matchExpr.topn, + matchExpr.extra_options, + ) + elif isinstance(matchExpr, MatchDenseExpr): + builder = builder.match_dense( + matchExpr.vector_column_name, + matchExpr.embedding_data, + matchExpr.embedding_data_type, + matchExpr.distance_type, + matchExpr.topn, + matchExpr.extra_options, + ) + elif isinstance(matchExpr, FusionExpr): + builder = builder.fusion( + matchExpr.method, matchExpr.topn, matchExpr.fusion_params + ) + else: + if len(filter_cond) > 0: + builder.filter(filter_cond) if orderBy.fields: builder.sort(order_by_expr_list) builder.offset(offset).limit(limit) @@ -282,7 +291,7 @@ class InfinityConnection(DocStoreConnection): return res def get( - self, chunkId: str, indexName: str, knowledgebaseIds: list[str] + self, chunkId: str, indexName: str, knowledgebaseIds: list[str] ) -> dict | None: inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) @@ -299,7 +308,7 @@ class InfinityConnection(DocStoreConnection): return res_fields.get(chunkId, None) def insert( - self, documents: list[dict], indexName: str, knowledgebaseId: str + self, documents: list[dict], indexName: str, knowledgebaseId: str ) -> list[str]: inf_conn = self.connPool.get_conn() db_instance = inf_conn.get_database(self.dbName) @@ -341,7 +350,7 @@ class InfinityConnection(DocStoreConnection): return [] def update( - self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str + self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str ) -> bool: # if 'position_list' in newValue: # logging.info(f"upsert position_list: {newValue['position_list']}") @@ -430,7 +439,7 @@ class InfinityConnection(DocStoreConnection): flags=re.IGNORECASE | re.MULTILINE, ) if not re.search( - r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE + r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE ): continue txts.append(t)