### What problem does this PR solve?

1. Remove unused code
2. Fix type mismatch, in nlp search and infinity search interface
3. Fix chunk list, get all chunks of this user.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai 2024-11-20 09:31:36 +08:00 committed by Yingfeng Zhang
parent c4f2464935
commit 2044bb0039
3 changed files with 54 additions and 46 deletions

View File

@ -17,13 +17,13 @@ from abc import ABC
import builtins import builtins
import json import json
import os import os
import logging
from functools import partial from functools import partial
from typing import Tuple, Union from typing import Tuple, Union
import pandas as pd import pandas as pd
from agent import settings from agent import settings
from agent.settings import flow_logger, DEBUG
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params" _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params" _DEPRECATED_PARAMS = "_deprecated_params"
@ -480,7 +480,6 @@ class ComponentBase(ABC):
upstream_outs = [] upstream_outs = []
if DEBUG: print(self.component_name, reversed_cpnts[::-1])
for u in reversed_cpnts[::-1]: for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval": if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":

View File

@ -23,7 +23,7 @@ from rag.nlp.search import Dealer
class KGSearch(Dealer): class KGSearch(Dealer):
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False): def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
def merge_into_first(sres, title="") -> dict[str, str]: def merge_into_first(sres, title="") -> dict[str, str]:
if not sres: if not sres:
return {} return {}

View File

@ -4,7 +4,7 @@ import re
import json import json
import time import time
import infinity import infinity
from infinity.common import ConflictType, InfinityException from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool from infinity.connection_pool import ConnectionPool
from rag import settings from rag import settings
@ -22,6 +22,7 @@ from rag.utils.doc_store_conn import (
OrderByExpr, OrderByExpr,
) )
def equivalent_condition_to_str(condition: dict) -> str: def equivalent_condition_to_str(condition: dict) -> str:
assert "_id" not in condition assert "_id" not in condition
cond = list() cond = list()
@ -65,7 +66,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool = connPool self.connPool = connPool
break break
except Exception as e: except Exception as e:
logging.warn(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.") logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5) time.sleep(5)
if self.connPool is None: if self.connPool is None:
msg = f"Infinity {infinity_uri} didn't become healthy in 120s." msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
@ -168,7 +169,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)
return True return True
except Exception as e: except Exception as e:
logging.warn(f"INFINITY indexExist {str(e)}") logging.warning(f"INFINITY indexExist {str(e)}")
return False return False
""" """
@ -176,16 +177,16 @@ class InfinityConnection(DocStoreConnection):
""" """
def search( def search(
self, self,
selectFields: list[str], selectFields: list[str],
highlightFields: list[str], highlightFields: list[str],
condition: dict, condition: dict,
matchExprs: list[MatchExpr], matchExprs: list[MatchExpr],
orderBy: OrderByExpr, orderBy: OrderByExpr,
offset: int, offset: int,
limit: int, limit: int,
indexNames: str|list[str], indexNames: str | list[str],
knowledgebaseIds: list[str], knowledgebaseIds: list[str],
) -> list[dict] | pl.DataFrame: ) -> list[dict] | pl.DataFrame:
""" """
TODO: Infinity doesn't provide highlight TODO: Infinity doesn't provide highlight
@ -219,8 +220,8 @@ class InfinityConnection(DocStoreConnection):
minimum_should_match = "0%" minimum_should_match = "0%"
if "minimum_should_match" in matchExpr.extra_options: if "minimum_should_match" in matchExpr.extra_options:
minimum_should_match = ( minimum_should_match = (
str(int(matchExpr.extra_options["minimum_should_match"] * 100)) str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ "%" + "%"
) )
matchExpr.extra_options.update( matchExpr.extra_options.update(
{"minimum_should_match": minimum_should_match} {"minimum_should_match": minimum_should_match}
@ -234,10 +235,14 @@ class InfinityConnection(DocStoreConnection):
for k, v in matchExpr.extra_options.items(): for k, v in matchExpr.extra_options.items():
if not isinstance(v, str): if not isinstance(v, str):
matchExpr.extra_options[k] = str(v) matchExpr.extra_options[k] = str(v)
order_by_expr_list = list()
if orderBy.fields: if orderBy.fields:
order_by_expr_list = list()
for order_field in orderBy.fields: for order_field in orderBy.fields:
order_by_expr_list.append((order_field[0], order_field[1] == 0)) if order_field[1] == 0:
order_by_expr_list.append((order_field[0], SortType.Asc))
else:
order_by_expr_list.append((order_field[0], SortType.Desc))
# Scatter search tables and gather the results # Scatter search tables and gather the results
for indexName in indexNames: for indexName in indexNames:
@ -249,28 +254,32 @@ class InfinityConnection(DocStoreConnection):
continue continue
table_list.append(table_name) table_list.append(table_name)
builder = table_instance.output(selectFields) builder = table_instance.output(selectFields)
for matchExpr in matchExprs: if len(matchExprs) > 0:
if isinstance(matchExpr, MatchTextExpr): for matchExpr in matchExprs:
fields = ",".join(matchExpr.fields) if isinstance(matchExpr, MatchTextExpr):
builder = builder.match_text( fields = ",".join(matchExpr.fields)
fields, builder = builder.match_text(
matchExpr.matching_text, fields,
matchExpr.topn, matchExpr.matching_text,
matchExpr.extra_options, matchExpr.topn,
) matchExpr.extra_options,
elif isinstance(matchExpr, MatchDenseExpr): )
builder = builder.match_dense( elif isinstance(matchExpr, MatchDenseExpr):
matchExpr.vector_column_name, builder = builder.match_dense(
matchExpr.embedding_data, matchExpr.vector_column_name,
matchExpr.embedding_data_type, matchExpr.embedding_data,
matchExpr.distance_type, matchExpr.embedding_data_type,
matchExpr.topn, matchExpr.distance_type,
matchExpr.extra_options, matchExpr.topn,
) matchExpr.extra_options,
elif isinstance(matchExpr, FusionExpr): )
builder = builder.fusion( elif isinstance(matchExpr, FusionExpr):
matchExpr.method, matchExpr.topn, matchExpr.fusion_params builder = builder.fusion(
) matchExpr.method, matchExpr.topn, matchExpr.fusion_params
)
else:
if len(filter_cond) > 0:
builder.filter(filter_cond)
if orderBy.fields: if orderBy.fields:
builder.sort(order_by_expr_list) builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit) builder.offset(offset).limit(limit)
@ -282,7 +291,7 @@ class InfinityConnection(DocStoreConnection):
return res return res
def get( def get(
self, chunkId: str, indexName: str, knowledgebaseIds: list[str] self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
) -> dict | None: ) -> dict | None:
inf_conn = self.connPool.get_conn() inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName) db_instance = inf_conn.get_database(self.dbName)
@ -299,7 +308,7 @@ class InfinityConnection(DocStoreConnection):
return res_fields.get(chunkId, None) return res_fields.get(chunkId, None)
def insert( def insert(
self, documents: list[dict], indexName: str, knowledgebaseId: str self, documents: list[dict], indexName: str, knowledgebaseId: str
) -> list[str]: ) -> list[str]:
inf_conn = self.connPool.get_conn() inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName) db_instance = inf_conn.get_database(self.dbName)
@ -341,7 +350,7 @@ class InfinityConnection(DocStoreConnection):
return [] return []
def update( def update(
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
) -> bool: ) -> bool:
# if 'position_list' in newValue: # if 'position_list' in newValue:
# logging.info(f"upsert position_list: {newValue['position_list']}") # logging.info(f"upsert position_list: {newValue['position_list']}")
@ -430,7 +439,7 @@ class InfinityConnection(DocStoreConnection):
flags=re.IGNORECASE | re.MULTILINE, flags=re.IGNORECASE | re.MULTILINE,
) )
if not re.search( if not re.search(
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
): ):
continue continue
txts.append(t) txts.append(t)