diff --git a/agent/component/base.py b/agent/component/base.py index e50b297de..8bdddb38e 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from abc import ABC import builtins import json -import os import logging +import os +from abc import ABC from functools import partial from typing import Any, Tuple, Union @@ -46,9 +46,6 @@ class ComponentParamBase(ABC): def check(self): raise NotImplementedError("Parameter Object should be checked.") - - def output(self): - return None @classmethod def _get_or_init_deprecated_params_set(cls): @@ -113,15 +110,11 @@ class ComponentParamBase(ABC): update_from_raw_conf = conf.get(_IS_RAW_CONF, True) if update_from_raw_conf: deprecated_params_set = self._get_or_init_deprecated_params_set() - feeded_deprecated_params_set = ( - self._get_or_init_feeded_deprecated_params_set() - ) + feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set() user_feeded_params_set = self._get_or_init_user_feeded_params_set() setattr(self, _IS_RAW_CONF, False) else: - feeded_deprecated_params_set = ( - self._get_or_init_feeded_deprecated_params_set(conf) - ) + feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set(conf) user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf) def _recursive_update_param(param, config, depth, prefix): @@ -157,15 +150,11 @@ class ComponentParamBase(ABC): else: # recursive set obj attr - sub_params = _recursive_update_param( - attr, config_value, depth + 1, prefix=f"{prefix}{config_key}." - ) + sub_params = _recursive_update_param(attr, config_value, depth + 1, prefix=f"{prefix}{config_key}.") setattr(param, config_key, sub_params) if not allow_redundant and redundant_attrs: - raise ValueError( - f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`" - ) + raise ValueError(f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`") return param @@ -196,9 +185,7 @@ class ComponentParamBase(ABC): param_validation_path_prefix = home_dir + "/param_validation/" param_name = type(self).__name__ - param_validation_path = "/".join( - [param_validation_path_prefix, param_name + ".json"] - ) + param_validation_path = "/".join([param_validation_path_prefix, param_name + ".json"]) validation_json = None @@ -231,11 +218,7 @@ class ComponentParamBase(ABC): break if not value_legal: - raise ValueError( - "Plase check runtime conf, {} = {} does not match user-parameter restriction".format( - variable, value - ) - ) + raise ValueError("Plase check runtime conf, {} = {} does not match user-parameter restriction".format(variable, value)) elif variable in validation_json: self._validate_param(attr, validation_json) @@ -243,94 +226,63 @@ class ComponentParamBase(ABC): @staticmethod def check_string(param, descr): if type(param).__name__ not in ["str"]: - raise ValueError( - descr + " {} not supported, should be string type".format(param) - ) + raise ValueError(descr + " {} not supported, should be string type".format(param)) @staticmethod def check_empty(param, descr): if not param: - raise ValueError( - descr + " does not support empty value." - ) + raise ValueError(descr + " does not support empty value.") @staticmethod def check_positive_integer(param, descr): if type(param).__name__ not in ["int", "long"] or param <= 0: - raise ValueError( - descr + " {} not supported, should be positive integer".format(param) - ) + raise ValueError(descr + " {} not supported, should be positive integer".format(param)) @staticmethod def check_positive_number(param, descr): if type(param).__name__ not in ["float", "int", "long"] or param <= 0: - raise ValueError( - descr + " {} not supported, should be positive numeric".format(param) - ) + raise ValueError(descr + " {} not supported, should be positive numeric".format(param)) @staticmethod def check_nonnegative_number(param, descr): if type(param).__name__ not in ["float", "int", "long"] or param < 0: - raise ValueError( - descr - + " {} not supported, should be non-negative numeric".format(param) - ) + raise ValueError(descr + " {} not supported, should be non-negative numeric".format(param)) @staticmethod def check_decimal_float(param, descr): if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1: - raise ValueError( - descr - + " {} not supported, should be a float number in range [0, 1]".format( - param - ) - ) + raise ValueError(descr + " {} not supported, should be a float number in range [0, 1]".format(param)) @staticmethod def check_boolean(param, descr): if type(param).__name__ != "bool": - raise ValueError( - descr + " {} not supported, should be bool type".format(param) - ) + raise ValueError(descr + " {} not supported, should be bool type".format(param)) @staticmethod def check_open_unit_interval(param, descr): if type(param).__name__ not in ["float"] or param <= 0 or param >= 1: - raise ValueError( - descr + " should be a numeric number between 0 and 1 exclusively" - ) + raise ValueError(descr + " should be a numeric number between 0 and 1 exclusively") @staticmethod def check_valid_value(param, descr, valid_values): if param not in valid_values: - raise ValueError( - descr - + " {} is not supported, it should be in {}".format(param, valid_values) - ) + raise ValueError(descr + " {} is not supported, it should be in {}".format(param, valid_values)) @staticmethod def check_defined_type(param, descr, types): if type(param).__name__ not in types: - raise ValueError( - descr + " {} not supported, should be one of {}".format(param, types) - ) + raise ValueError(descr + " {} not supported, should be one of {}".format(param, types)) @staticmethod def check_and_change_lower(param, valid_list, descr=""): if type(param).__name__ != "str": - raise ValueError( - descr - + " {} not supported, should be one of {}".format(param, valid_list) - ) + raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list)) lower_param = param.lower() if lower_param in valid_list: return lower_param else: - raise ValueError( - descr - + " {} not supported, should be one of {}".format(param, valid_list) - ) + raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list)) @staticmethod def _greater_equal_than(value, limit): @@ -344,11 +296,7 @@ class ComponentParamBase(ABC): def _range(value, ranges): in_range = False for left_limit, right_limit in ranges: - if ( - left_limit - settings.FLOAT_ZERO - <= value - <= right_limit + settings.FLOAT_ZERO - ): + if left_limit - settings.FLOAT_ZERO <= value <= right_limit + settings.FLOAT_ZERO: in_range = True break @@ -364,16 +312,11 @@ class ComponentParamBase(ABC): def _warn_deprecated_param(self, param_name, descr): if self._deprecated_params_set.get(param_name): - logging.warning( - f"{descr} {param_name} is deprecated and ignored in this version." - ) + logging.warning(f"{descr} {param_name} is deprecated and ignored in this version.") def _warn_to_deprecate_param(self, param_name, descr, new_param): if self._deprecated_params_set.get(param_name): - logging.warning( - f"{descr} {param_name} will be deprecated in future release; " - f"please use {new_param} instead." - ) + logging.warning(f"{descr} {param_name} will be deprecated in future release; please use {new_param} instead.") return True return False @@ -398,14 +341,16 @@ class ComponentBase(ABC): "params": {}, "output": {}, "inputs": {} - }}""".format(self.component_name, - self._param, - json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False), - json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False) + }}""".format( + self.component_name, + self._param, + json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False), + json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False), ) def __init__(self, canvas, id, param: ComponentParamBase): from agent.canvas import Canvas # Local import to avoid cyclic dependency + assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas" self._canvas = canvas self._id = id @@ -413,15 +358,17 @@ class ComponentBase(ABC): self._param.check() def get_dependent_components(self): - cpnts = set([para["component_id"].split("@")[0] for para in self._param.query \ - if para.get("component_id") \ - and para["component_id"].lower().find("answer") < 0 \ - and para["component_id"].lower().find("begin") < 0]) + cpnts = set( + [ + para["component_id"].split("@")[0] + for para in self._param.query + if para.get("component_id") and para["component_id"].lower().find("answer") < 0 and para["component_id"].lower().find("begin") < 0 + ] + ) return list(cpnts) def run(self, history, **kwargs): - logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), - json.dumps(kwargs, ensure_ascii=False))) + logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), json.dumps(kwargs, ensure_ascii=False))) self._param.debug_inputs = [] try: res = self._run(history, **kwargs) @@ -468,7 +415,7 @@ class ComponentBase(ABC): def set_infor(self, v): setattr(self._param, self._param.infor_var_name, v) - + def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]: outs = [] for q in sources: @@ -485,7 +432,7 @@ class ComponentBase(ABC): if q["component_id"].lower().find("answer") == 0: txt = [] - for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]: + for r, c in self._canvas.history[::-1][: self._param.message_history_window_size][::-1]: txt.append(f"{r.upper()}:{c}") txt = "\n".join(txt) outs.append(pd.DataFrame([{"content": txt}])) @@ -495,6 +442,7 @@ class ComponentBase(ABC): elif q.get("value"): outs.append(pd.DataFrame([{"content": q["value"]}])) return outs + def get_input(self): if self._param.debug_inputs: return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")]) @@ -515,21 +463,16 @@ class ComponentBase(ABC): content: str if len(records) > 1: - content = "\n".join( - [str(d["content"]) for d in records] - ) + content = "\n".join([str(d["content"]) for d in records]) else: content = records[0]["content"] - self._param.inputs.append({ - "component_id": records[0].get("component_id"), - "content": content - }) + self._param.inputs.append({"component_id": records[0].get("component_id"), "content": content}) if outs: df = pd.concat(outs, ignore_index=True) if "content" in df: - df = df.drop_duplicates(subset=['content']).reset_index(drop=True) + df = df.drop_duplicates(subset=["content"]).reset_index(drop=True) return df upstream_outs = [] @@ -543,9 +486,8 @@ class ComponentBase(ABC): o["component_id"] = u upstream_outs.append(o) continue - #if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue - if self.component_name.lower().find("switch") < 0 \ - and self.get_component_name(u) in ["relevant", "categorize"]: + # if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue + if self.component_name.lower().find("switch") < 0 and self.get_component_name(u) in ["relevant", "categorize"]: continue if u.lower().find("answer") >= 0: for r, c in self._canvas.history[::-1]: @@ -565,7 +507,7 @@ class ComponentBase(ABC): df = pd.concat(upstream_outs, ignore_index=True) if "content" in df: - df = df.drop_duplicates(subset=['content']).reset_index(drop=True) + df = df.drop_duplicates(subset=["content"]).reset_index(drop=True) self._param.inputs = [] for _, r in df.iterrows(): @@ -617,5 +559,5 @@ class ComponentBase(ABC): return self._canvas.get_component(pid)["obj"] def get_upstream(self): - cpn_nms = self._canvas.get_component(self._id)['upstream'] + cpn_nms = self._canvas.get_component(self._id)["upstream"] return cpn_nms diff --git a/agent/component/code.py b/agent/component/code.py index 0abf0b472..baa38324c 100644 --- a/agent/component/code.py +++ b/agent/component/code.py @@ -89,13 +89,23 @@ class Code(ComponentBase, ABC): if "value" in param: arguments[input["name"]] = param["value"] else: - cpn = self._canvas.get_component(input["component_id"])["obj"] - if cpn.component_name.lower() == "answer": + refered_component = self._canvas.get_component(input["component_id"])["obj"] + refered_component_name = refered_component.component_name + refered_component_id = refered_component._id + if refered_component_name.lower() == "answer": arguments[input["name"]] = self._canvas.get_history(1)[0]["content"] continue - _, out = cpn.output(allow_partial=False) - if not out.empty: - arguments[input["name"]] = "\n".join(out["content"]) + + debug_inputs = self._param.debug_inputs + if debug_inputs: + for param in debug_inputs: + if param["key"] == refered_component_id: + if "value" in param and param["name"] == input["name"]: + arguments[input["name"]] = param["value"] + else: + _, out = refered_component.output(allow_partial=False) + if not out.empty: + arguments[input["name"]] = "\n".join(out["content"]) return self._execute_code( language=self._param.lang,