### What problem does this PR solve?

1. Remove unused code
2. Fix type mismatches in the NLP search and Infinity search interfaces
3. Fix the chunk list (getting all chunks of this user)

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai 2024-11-20 09:31:36 +08:00 committed by Yingfeng Zhang
parent c4f2464935
commit 2044bb0039
3 changed files with 54 additions and 46 deletions

View File

@ -17,13 +17,13 @@ from abc import ABC
import builtins
import json
import os
import logging
from functools import partial
from typing import Tuple, Union
import pandas as pd
from agent import settings
from agent.settings import flow_logger, DEBUG
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
@ -480,7 +480,6 @@ class ComponentBase(ABC):
upstream_outs = []
if DEBUG: print(self.component_name, reversed_cpnts[::-1])
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":

View File

@ -23,7 +23,7 @@ from rag.nlp.search import Dealer
class KGSearch(Dealer):
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
def merge_into_first(sres, title="") -> dict[str, str]:
if not sres:
return {}

View File

@ -4,7 +4,7 @@ import re
import json
import time
import infinity
from infinity.common import ConflictType, InfinityException
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from rag import settings
@ -22,6 +22,7 @@ from rag.utils.doc_store_conn import (
OrderByExpr,
)
def equivalent_condition_to_str(condition: dict) -> str:
assert "_id" not in condition
cond = list()
@ -65,7 +66,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool = connPool
break
except Exception as e:
logging.warn(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
@ -168,7 +169,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
logging.warn(f"INFINITY indexExist {str(e)}")
logging.warning(f"INFINITY indexExist {str(e)}")
return False
"""
@ -176,16 +177,16 @@ class InfinityConnection(DocStoreConnection):
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
) -> list[dict] | pl.DataFrame:
"""
TODO: Infinity doesn't provide highlight
@ -219,8 +220,8 @@ class InfinityConnection(DocStoreConnection):
minimum_should_match = "0%"
if "minimum_should_match" in matchExpr.extra_options:
minimum_should_match = (
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ "%"
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ "%"
)
matchExpr.extra_options.update(
{"minimum_should_match": minimum_should_match}
@ -234,10 +235,14 @@ class InfinityConnection(DocStoreConnection):
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
order_by_expr_list = list()
if orderBy.fields:
order_by_expr_list = list()
for order_field in orderBy.fields:
order_by_expr_list.append((order_field[0], order_field[1] == 0))
if order_field[1] == 0:
order_by_expr_list.append((order_field[0], SortType.Asc))
else:
order_by_expr_list.append((order_field[0], SortType.Desc))
# Scatter search tables and gather the results
for indexName in indexNames:
@ -249,28 +254,32 @@ class InfinityConnection(DocStoreConnection):
continue
table_list.append(table_name)
builder = table_instance.output(selectFields)
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
)
if len(matchExprs) > 0:
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
)
else:
if len(filter_cond) > 0:
builder.filter(filter_cond)
if orderBy.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
@ -282,7 +291,7 @@ class InfinityConnection(DocStoreConnection):
return res
def get(
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
@ -299,7 +308,7 @@ class InfinityConnection(DocStoreConnection):
return res_fields.get(chunkId, None)
def insert(
self, documents: list[dict], indexName: str, knowledgebaseId: str
self, documents: list[dict], indexName: str, knowledgebaseId: str
) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
@ -341,7 +350,7 @@ class InfinityConnection(DocStoreConnection):
return []
def update(
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
) -> bool:
# if 'position_list' in newValue:
# logging.info(f"upsert position_list: {newValue['position_list']}")
@ -430,7 +439,7 @@ class InfinityConnection(DocStoreConnection):
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
):
continue
txts.append(t)