mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-22 22:20:07 +08:00

### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
587 lines
22 KiB
Python
587 lines
22 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
from abc import ABC
|
|
import builtins
|
|
import json
|
|
import os
|
|
import logging
|
|
from functools import partial
|
|
from typing import Tuple, Union
|
|
|
|
import pandas as pd
|
|
|
|
from agent import settings
|
|
|
|
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
|
|
_DEPRECATED_PARAMS = "_deprecated_params"
|
|
_USER_FEEDED_PARAMS = "_user_feeded_params"
|
|
_IS_RAW_CONF = "_is_raw_conf"
|
|
|
|
|
|
class ComponentParamBase(ABC):
|
|
def __init__(self):
|
|
self.output_var_name = "output"
|
|
self.message_history_window_size = 22
|
|
self.query = []
|
|
self.inputs = []
|
|
self.debug_inputs = []
|
|
|
|
def set_name(self, name: str):
|
|
self._name = name
|
|
return self
|
|
|
|
def check(self):
|
|
raise NotImplementedError("Parameter Object should be checked.")
|
|
|
|
@classmethod
|
|
def _get_or_init_deprecated_params_set(cls):
|
|
if not hasattr(cls, _DEPRECATED_PARAMS):
|
|
setattr(cls, _DEPRECATED_PARAMS, set())
|
|
return getattr(cls, _DEPRECATED_PARAMS)
|
|
|
|
def _get_or_init_feeded_deprecated_params_set(self, conf=None):
|
|
if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
|
|
if conf is None:
|
|
setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
|
|
else:
|
|
setattr(
|
|
self,
|
|
_FEEDED_DEPRECATED_PARAMS,
|
|
set(conf[_FEEDED_DEPRECATED_PARAMS]),
|
|
)
|
|
return getattr(self, _FEEDED_DEPRECATED_PARAMS)
|
|
|
|
def _get_or_init_user_feeded_params_set(self, conf=None):
|
|
if not hasattr(self, _USER_FEEDED_PARAMS):
|
|
if conf is None:
|
|
setattr(self, _USER_FEEDED_PARAMS, set())
|
|
else:
|
|
setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
|
|
return getattr(self, _USER_FEEDED_PARAMS)
|
|
|
|
def get_user_feeded(self):
|
|
return self._get_or_init_user_feeded_params_set()
|
|
|
|
def get_feeded_deprecated_params(self):
|
|
return self._get_or_init_feeded_deprecated_params_set()
|
|
|
|
@property
|
|
def _deprecated_params_set(self):
|
|
return {name: True for name in self.get_feeded_deprecated_params()}
|
|
|
|
def __str__(self):
|
|
return json.dumps(self.as_dict(), ensure_ascii=False)
|
|
|
|
def as_dict(self):
|
|
def _recursive_convert_obj_to_dict(obj):
|
|
ret_dict = {}
|
|
for attr_name in list(obj.__dict__):
|
|
if attr_name in [_FEEDED_DEPRECATED_PARAMS, _DEPRECATED_PARAMS, _USER_FEEDED_PARAMS, _IS_RAW_CONF]:
|
|
continue
|
|
# get attr
|
|
attr = getattr(obj, attr_name)
|
|
if isinstance(attr, pd.DataFrame):
|
|
ret_dict[attr_name] = attr.to_dict()
|
|
continue
|
|
if attr and type(attr).__name__ not in dir(builtins):
|
|
ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
|
|
else:
|
|
ret_dict[attr_name] = attr
|
|
|
|
return ret_dict
|
|
|
|
return _recursive_convert_obj_to_dict(self)
|
|
|
|
def update(self, conf, allow_redundant=False):
|
|
update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
|
|
if update_from_raw_conf:
|
|
deprecated_params_set = self._get_or_init_deprecated_params_set()
|
|
feeded_deprecated_params_set = (
|
|
self._get_or_init_feeded_deprecated_params_set()
|
|
)
|
|
user_feeded_params_set = self._get_or_init_user_feeded_params_set()
|
|
setattr(self, _IS_RAW_CONF, False)
|
|
else:
|
|
feeded_deprecated_params_set = (
|
|
self._get_or_init_feeded_deprecated_params_set(conf)
|
|
)
|
|
user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
|
|
|
|
def _recursive_update_param(param, config, depth, prefix):
|
|
if depth > settings.PARAM_MAXDEPTH:
|
|
raise ValueError("Param define nesting too deep!!!, can not parse it")
|
|
|
|
inst_variables = param.__dict__
|
|
redundant_attrs = []
|
|
for config_key, config_value in config.items():
|
|
# redundant attr
|
|
if config_key not in inst_variables:
|
|
if not update_from_raw_conf and config_key.startswith("_"):
|
|
setattr(param, config_key, config_value)
|
|
else:
|
|
setattr(param, config_key, config_value)
|
|
# redundant_attrs.append(config_key)
|
|
continue
|
|
|
|
full_config_key = f"{prefix}{config_key}"
|
|
|
|
if update_from_raw_conf:
|
|
# add user feeded params
|
|
user_feeded_params_set.add(full_config_key)
|
|
|
|
# update user feeded deprecated param set
|
|
if full_config_key in deprecated_params_set:
|
|
feeded_deprecated_params_set.add(full_config_key)
|
|
|
|
# supported attr
|
|
attr = getattr(param, config_key)
|
|
if type(attr).__name__ in dir(builtins) or attr is None:
|
|
setattr(param, config_key, config_value)
|
|
|
|
else:
|
|
# recursive set obj attr
|
|
sub_params = _recursive_update_param(
|
|
attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
|
|
)
|
|
setattr(param, config_key, sub_params)
|
|
|
|
if not allow_redundant and redundant_attrs:
|
|
raise ValueError(
|
|
f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
|
|
)
|
|
|
|
return param
|
|
|
|
return _recursive_update_param(param=self, config=conf, depth=0, prefix="")
|
|
|
|
def extract_not_builtin(self):
|
|
def _get_not_builtin_types(obj):
|
|
ret_dict = {}
|
|
for variable in obj.__dict__:
|
|
attr = getattr(obj, variable)
|
|
if attr and type(attr).__name__ not in dir(builtins):
|
|
ret_dict[variable] = _get_not_builtin_types(attr)
|
|
|
|
return ret_dict
|
|
|
|
return _get_not_builtin_types(self)
|
|
|
|
def validate(self):
|
|
self.builtin_types = dir(builtins)
|
|
self.func = {
|
|
"ge": self._greater_equal_than,
|
|
"le": self._less_equal_than,
|
|
"in": self._in,
|
|
"not_in": self._not_in,
|
|
"range": self._range,
|
|
}
|
|
home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
|
|
param_validation_path_prefix = home_dir + "/param_validation/"
|
|
|
|
param_name = type(self).__name__
|
|
param_validation_path = "/".join(
|
|
[param_validation_path_prefix, param_name + ".json"]
|
|
)
|
|
|
|
validation_json = None
|
|
|
|
try:
|
|
with open(param_validation_path, "r") as fin:
|
|
validation_json = json.loads(fin.read())
|
|
except BaseException:
|
|
return
|
|
|
|
self._validate_param(self, validation_json)
|
|
|
|
def _validate_param(self, param_obj, validation_json):
|
|
default_section = type(param_obj).__name__
|
|
var_list = param_obj.__dict__
|
|
|
|
for variable in var_list:
|
|
attr = getattr(param_obj, variable)
|
|
|
|
if type(attr).__name__ in self.builtin_types or attr is None:
|
|
if variable not in validation_json:
|
|
continue
|
|
|
|
validation_dict = validation_json[default_section][variable]
|
|
value = getattr(param_obj, variable)
|
|
value_legal = False
|
|
|
|
for op_type in validation_dict:
|
|
if self.func[op_type](value, validation_dict[op_type]):
|
|
value_legal = True
|
|
break
|
|
|
|
if not value_legal:
|
|
raise ValueError(
|
|
"Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
|
|
variable, value
|
|
)
|
|
)
|
|
|
|
elif variable in validation_json:
|
|
self._validate_param(attr, validation_json)
|
|
|
|
@staticmethod
|
|
def check_string(param, descr):
|
|
if type(param).__name__ not in ["str"]:
|
|
raise ValueError(
|
|
descr + " {} not supported, should be string type".format(param)
|
|
)
|
|
|
|
@staticmethod
|
|
def check_empty(param, descr):
|
|
if not param:
|
|
raise ValueError(
|
|
descr + " does not support empty value."
|
|
)
|
|
|
|
@staticmethod
|
|
def check_positive_integer(param, descr):
|
|
if type(param).__name__ not in ["int", "long"] or param <= 0:
|
|
raise ValueError(
|
|
descr + " {} not supported, should be positive integer".format(param)
|
|
)
|
|
|
|
@staticmethod
|
|
def check_positive_number(param, descr):
|
|
if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
|
|
raise ValueError(
|
|
descr + " {} not supported, should be positive numeric".format(param)
|
|
)
|
|
|
|
@staticmethod
|
|
def check_nonnegative_number(param, descr):
|
|
if type(param).__name__ not in ["float", "int", "long"] or param < 0:
|
|
raise ValueError(
|
|
descr
|
|
+ " {} not supported, should be non-negative numeric".format(param)
|
|
)
|
|
|
|
@staticmethod
|
|
def check_decimal_float(param, descr):
|
|
if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
|
|
raise ValueError(
|
|
descr
|
|
+ " {} not supported, should be a float number in range [0, 1]".format(
|
|
param
|
|
)
|
|
)
|
|
|
|
@staticmethod
|
|
def check_boolean(param, descr):
|
|
if type(param).__name__ != "bool":
|
|
raise ValueError(
|
|
descr + " {} not supported, should be bool type".format(param)
|
|
)
|
|
|
|
@staticmethod
|
|
def check_open_unit_interval(param, descr):
|
|
if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
|
|
raise ValueError(
|
|
descr + " should be a numeric number between 0 and 1 exclusively"
|
|
)
|
|
|
|
@staticmethod
|
|
def check_valid_value(param, descr, valid_values):
|
|
if param not in valid_values:
|
|
raise ValueError(
|
|
descr
|
|
+ " {} is not supported, it should be in {}".format(param, valid_values)
|
|
)
|
|
|
|
@staticmethod
|
|
def check_defined_type(param, descr, types):
|
|
if type(param).__name__ not in types:
|
|
raise ValueError(
|
|
descr + " {} not supported, should be one of {}".format(param, types)
|
|
)
|
|
|
|
@staticmethod
|
|
def check_and_change_lower(param, valid_list, descr=""):
|
|
if type(param).__name__ != "str":
|
|
raise ValueError(
|
|
descr
|
|
+ " {} not supported, should be one of {}".format(param, valid_list)
|
|
)
|
|
|
|
lower_param = param.lower()
|
|
if lower_param in valid_list:
|
|
return lower_param
|
|
else:
|
|
raise ValueError(
|
|
descr
|
|
+ " {} not supported, should be one of {}".format(param, valid_list)
|
|
)
|
|
|
|
@staticmethod
|
|
def _greater_equal_than(value, limit):
|
|
return value >= limit - settings.FLOAT_ZERO
|
|
|
|
@staticmethod
|
|
def _less_equal_than(value, limit):
|
|
return value <= limit + settings.FLOAT_ZERO
|
|
|
|
@staticmethod
|
|
def _range(value, ranges):
|
|
in_range = False
|
|
for left_limit, right_limit in ranges:
|
|
if (
|
|
left_limit - settings.FLOAT_ZERO
|
|
<= value
|
|
<= right_limit + settings.FLOAT_ZERO
|
|
):
|
|
in_range = True
|
|
break
|
|
|
|
return in_range
|
|
|
|
@staticmethod
|
|
def _in(value, right_value_list):
|
|
return value in right_value_list
|
|
|
|
@staticmethod
|
|
def _not_in(value, wrong_value_list):
|
|
return value not in wrong_value_list
|
|
|
|
def _warn_deprecated_param(self, param_name, descr):
|
|
if self._deprecated_params_set.get(param_name):
|
|
logging.warning(
|
|
f"{descr} {param_name} is deprecated and ignored in this version."
|
|
)
|
|
|
|
def _warn_to_deprecate_param(self, param_name, descr, new_param):
|
|
if self._deprecated_params_set.get(param_name):
|
|
logging.warning(
|
|
f"{descr} {param_name} will be deprecated in future release; "
|
|
f"please use {new_param} instead."
|
|
)
|
|
return True
|
|
return False
|
|
|
|
|
|
class ComponentBase(ABC):
|
|
component_name: str
|
|
|
|
def __str__(self):
|
|
"""
|
|
{
|
|
"component_name": "Begin",
|
|
"params": {}
|
|
}
|
|
"""
|
|
return """{{
|
|
"component_name": "{}",
|
|
"params": {},
|
|
"output": {},
|
|
"inputs": {}
|
|
}}""".format(self.component_name,
|
|
self._param,
|
|
json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False),
|
|
json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False)
|
|
)
|
|
|
|
def __init__(self, canvas, id, param: ComponentParamBase):
|
|
self._canvas = canvas
|
|
self._id = id
|
|
self._param = param
|
|
self._param.check()
|
|
|
|
def get_dependent_components(self):
|
|
cpnts = set([para["component_id"].split("@")[0] for para in self._param.query \
|
|
if para.get("component_id") \
|
|
and para["component_id"].lower().find("answer") < 0 \
|
|
and para["component_id"].lower().find("begin") < 0])
|
|
return list(cpnts)
|
|
|
|
def run(self, history, **kwargs):
|
|
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
|
|
json.dumps(kwargs, ensure_ascii=False)))
|
|
self._param.debug_inputs = []
|
|
try:
|
|
res = self._run(history, **kwargs)
|
|
self.set_output(res)
|
|
except Exception as e:
|
|
self.set_output(pd.DataFrame([{"content": str(e)}]))
|
|
raise e
|
|
|
|
return res
|
|
|
|
def _run(self, history, **kwargs):
|
|
raise NotImplementedError()
|
|
|
|
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
|
|
o = getattr(self._param, self._param.output_var_name)
|
|
if not isinstance(o, partial):
|
|
if not isinstance(o, pd.DataFrame):
|
|
if isinstance(o, list):
|
|
return self._param.output_var_name, pd.DataFrame(o)
|
|
if o is None:
|
|
return self._param.output_var_name, pd.DataFrame()
|
|
return self._param.output_var_name, pd.DataFrame([{"content": str(o)}])
|
|
return self._param.output_var_name, o
|
|
|
|
if allow_partial or not isinstance(o, partial):
|
|
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
|
|
return pd.DataFrame(o if isinstance(o, list) else [o])
|
|
return self._param.output_var_name, o
|
|
|
|
outs = None
|
|
for oo in o():
|
|
if not isinstance(oo, pd.DataFrame):
|
|
outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
|
|
else:
|
|
outs = oo
|
|
return self._param.output_var_name, outs
|
|
|
|
def reset(self):
|
|
setattr(self._param, self._param.output_var_name, None)
|
|
self._param.inputs = []
|
|
|
|
def set_output(self, v):
|
|
setattr(self._param, self._param.output_var_name, v)
|
|
|
|
def get_input(self):
|
|
if self._param.debug_inputs:
|
|
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])
|
|
|
|
reversed_cpnts = []
|
|
if len(self._canvas.path) > 1:
|
|
reversed_cpnts.extend(self._canvas.path[-2])
|
|
reversed_cpnts.extend(self._canvas.path[-1])
|
|
|
|
if self._param.query:
|
|
self._param.inputs = []
|
|
outs = []
|
|
for q in self._param.query:
|
|
if q.get("component_id"):
|
|
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
|
|
cpn_id, key = q["component_id"].split("@")
|
|
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
|
if p["key"] == key:
|
|
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
|
|
self._param.inputs.append({"component_id": q["component_id"],
|
|
"content": p.get("value", "")})
|
|
break
|
|
else:
|
|
assert False, f"Can't find parameter '{key}' for {cpn_id}"
|
|
continue
|
|
|
|
if q["component_id"].lower().find("answer") == 0:
|
|
txt = []
|
|
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
|
|
txt.append(f"{r.upper()}: {c}")
|
|
txt = "\n".join(txt)
|
|
self._param.inputs.append({"content": txt, "component_id": q["component_id"]})
|
|
outs.append(pd.DataFrame([{"content": txt}]))
|
|
continue
|
|
|
|
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
|
|
self._param.inputs.append({"component_id": q["component_id"],
|
|
"content": "\n".join(
|
|
[str(d["content"]) for d in outs[-1].to_dict('records')])})
|
|
elif q.get("value"):
|
|
self._param.inputs.append({"component_id": None, "content": q["value"]})
|
|
outs.append(pd.DataFrame([{"content": q["value"]}]))
|
|
if outs:
|
|
df = pd.concat(outs, ignore_index=True)
|
|
if "content" in df:
|
|
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
|
|
return df
|
|
|
|
upstream_outs = []
|
|
|
|
for u in reversed_cpnts[::-1]:
|
|
if self.get_component_name(u) in ["switch", "concentrator"]:
|
|
continue
|
|
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
|
|
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
|
if o is not None:
|
|
o["component_id"] = u
|
|
upstream_outs.append(o)
|
|
continue
|
|
#if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
|
|
if self.component_name.lower().find("switch") < 0 \
|
|
and self.get_component_name(u) in ["relevant", "categorize"]:
|
|
continue
|
|
if u.lower().find("answer") >= 0:
|
|
for r, c in self._canvas.history[::-1]:
|
|
if r == "user":
|
|
upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
|
|
break
|
|
break
|
|
if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
|
|
continue
|
|
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
|
if o is not None:
|
|
o["component_id"] = u
|
|
upstream_outs.append(o)
|
|
break
|
|
|
|
assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
|
|
|
|
df = pd.concat(upstream_outs, ignore_index=True)
|
|
if "content" in df:
|
|
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
|
|
|
|
self._param.inputs = []
|
|
for _, r in df.iterrows():
|
|
self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
|
|
|
|
return df
|
|
|
|
def get_input_elements(self):
|
|
assert self._param.query, "Please identify input parameters firstly."
|
|
eles = []
|
|
for q in self._param.query:
|
|
if q.get("component_id"):
|
|
cpn_id = q["component_id"]
|
|
if cpn_id.split("@")[0].lower().find("begin") >= 0:
|
|
cpn_id, key = cpn_id.split("@")
|
|
eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
|
|
continue
|
|
|
|
eles.append({"name": self._canvas.get_component_name(cpn_id), "key": cpn_id})
|
|
else:
|
|
eles.append({"key": q["value"], "name": q["value"], "value": q["value"]})
|
|
return eles
|
|
|
|
def get_stream_input(self):
|
|
reversed_cpnts = []
|
|
if len(self._canvas.path) > 1:
|
|
reversed_cpnts.extend(self._canvas.path[-2])
|
|
reversed_cpnts.extend(self._canvas.path[-1])
|
|
|
|
for u in reversed_cpnts[::-1]:
|
|
if self.get_component_name(u) in ["switch", "answer"]:
|
|
continue
|
|
return self._canvas.get_component(u)["obj"].output()[1]
|
|
|
|
@staticmethod
|
|
def be_output(v):
|
|
return pd.DataFrame([{"content": v}])
|
|
|
|
def get_component_name(self, cpn_id):
|
|
return self._canvas.get_component(cpn_id)["obj"].component_name.lower()
|
|
|
|
def debug(self, **kwargs):
|
|
return self._run([], **kwargs)
|
|
|
|
def get_parent(self):
|
|
pid = self._canvas.get_component(self._id)["parent_id"]
|
|
return self._canvas.get_component(pid)["obj"]
|