diff --git a/agent/component/base.py b/agent/component/base.py
index fc44b4da7..5624b7bff 100644
--- a/agent/component/base.py
+++ b/agent/component/base.py
@@ -17,13 +17,13 @@ from abc import ABC
import builtins
import json
import os
+import logging
from functools import partial
from typing import Tuple, Union
import pandas as pd
from agent import settings
-from agent.settings import flow_logger, DEBUG
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
@@ -480,7 +480,6 @@ class ComponentBase(ABC):
upstream_outs = []
- if DEBUG: print(self.component_name, reversed_cpnts[::-1])
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
diff --git a/graphrag/search.py b/graphrag/search.py
index 24735c4b5..3fa962fcb 100644
--- a/graphrag/search.py
+++ b/graphrag/search.py
@@ -23,7 +23,7 @@ from rag.nlp.search import Dealer
class KGSearch(Dealer):
- def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
+ def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
def merge_into_first(sres, title="") -> dict[str, str]:
if not sres:
return {}
diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py
index ad5a83ab7..4d7ef0dcd 100644
--- a/rag/utils/infinity_conn.py
+++ b/rag/utils/infinity_conn.py
@@ -4,7 +4,7 @@ import re
import json
import time
import infinity
-from infinity.common import ConflictType, InfinityException
+from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from rag import settings
@@ -22,6 +22,7 @@ from rag.utils.doc_store_conn import (
OrderByExpr,
)
+
def equivalent_condition_to_str(condition: dict) -> str:
assert "_id" not in condition
cond = list()
@@ -65,7 +66,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool = connPool
break
except Exception as e:
- logging.warn(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
+ logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
@@ -168,7 +169,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
- logging.warn(f"INFINITY indexExist {str(e)}")
+ logging.warning(f"INFINITY indexExist {str(e)}")
return False
"""
@@ -176,16 +177,16 @@ class InfinityConnection(DocStoreConnection):
"""
def search(
- self,
- selectFields: list[str],
- highlightFields: list[str],
- condition: dict,
- matchExprs: list[MatchExpr],
- orderBy: OrderByExpr,
- offset: int,
- limit: int,
- indexNames: str|list[str],
- knowledgebaseIds: list[str],
+ self,
+ selectFields: list[str],
+ highlightFields: list[str],
+ condition: dict,
+ matchExprs: list[MatchExpr],
+ orderBy: OrderByExpr,
+ offset: int,
+ limit: int,
+ indexNames: str | list[str],
+ knowledgebaseIds: list[str],
) -> list[dict] | pl.DataFrame:
"""
TODO: Infinity doesn't provide highlight
@@ -219,8 +220,8 @@ class InfinityConnection(DocStoreConnection):
minimum_should_match = "0%"
if "minimum_should_match" in matchExpr.extra_options:
minimum_should_match = (
- str(int(matchExpr.extra_options["minimum_should_match"] * 100))
- + "%"
+ str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ + "%"
)
matchExpr.extra_options.update(
{"minimum_should_match": minimum_should_match}
@@ -234,10 +235,14 @@ class InfinityConnection(DocStoreConnection):
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
+
+ order_by_expr_list = list()
if orderBy.fields:
- order_by_expr_list = list()
for order_field in orderBy.fields:
- order_by_expr_list.append((order_field[0], order_field[1] == 0))
+ if order_field[1] == 0:
+ order_by_expr_list.append((order_field[0], SortType.Asc))
+ else:
+ order_by_expr_list.append((order_field[0], SortType.Desc))
# Scatter search tables and gather the results
for indexName in indexNames:
@@ -249,28 +254,32 @@ class InfinityConnection(DocStoreConnection):
continue
table_list.append(table_name)
builder = table_instance.output(selectFields)
- for matchExpr in matchExprs:
- if isinstance(matchExpr, MatchTextExpr):
- fields = ",".join(matchExpr.fields)
- builder = builder.match_text(
- fields,
- matchExpr.matching_text,
- matchExpr.topn,
- matchExpr.extra_options,
- )
- elif isinstance(matchExpr, MatchDenseExpr):
- builder = builder.match_dense(
- matchExpr.vector_column_name,
- matchExpr.embedding_data,
- matchExpr.embedding_data_type,
- matchExpr.distance_type,
- matchExpr.topn,
- matchExpr.extra_options,
- )
- elif isinstance(matchExpr, FusionExpr):
- builder = builder.fusion(
- matchExpr.method, matchExpr.topn, matchExpr.fusion_params
- )
+ if len(matchExprs) > 0:
+ for matchExpr in matchExprs:
+ if isinstance(matchExpr, MatchTextExpr):
+ fields = ",".join(matchExpr.fields)
+ builder = builder.match_text(
+ fields,
+ matchExpr.matching_text,
+ matchExpr.topn,
+ matchExpr.extra_options,
+ )
+ elif isinstance(matchExpr, MatchDenseExpr):
+ builder = builder.match_dense(
+ matchExpr.vector_column_name,
+ matchExpr.embedding_data,
+ matchExpr.embedding_data_type,
+ matchExpr.distance_type,
+ matchExpr.topn,
+ matchExpr.extra_options,
+ )
+ elif isinstance(matchExpr, FusionExpr):
+ builder = builder.fusion(
+ matchExpr.method, matchExpr.topn, matchExpr.fusion_params
+ )
+ else:
+ if len(filter_cond) > 0:
+ builder.filter(filter_cond)
if orderBy.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
@@ -282,7 +291,7 @@ class InfinityConnection(DocStoreConnection):
return res
def get(
- self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
+ self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
@@ -299,7 +308,7 @@ class InfinityConnection(DocStoreConnection):
return res_fields.get(chunkId, None)
def insert(
- self, documents: list[dict], indexName: str, knowledgebaseId: str
+ self, documents: list[dict], indexName: str, knowledgebaseId: str
) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
@@ -341,7 +350,7 @@ class InfinityConnection(DocStoreConnection):
return []
def update(
- self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
+ self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
) -> bool:
# if 'position_list' in newValue:
# logging.info(f"upsert position_list: {newValue['position_list']}")
@@ -430,7 +439,7 @@ class InfinityConnection(DocStoreConnection):
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(
- r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE
+ r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE
):
continue
txts.append(t)