From 8d9238db14bbca937966e98ee0d10314ec41e3ca Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Mon, 4 Nov 2024 09:53:41 +0800 Subject: [PATCH] fix es search parameter error (#3169) ### What problem does this PR solve? #3151 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- agent/canvas.py | 7 +++---- agent/component/base.py | 4 ++++ graphrag/search.py | 6 +++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/agent/canvas.py b/agent/canvas.py index 45a231ed7..bd1042850 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -194,10 +194,9 @@ class Canvas(ABC): self.answer.append(c) else: if DEBUG: print("RUN: ", c) - if cpn.component_name == "Generate": - cpids = cpn.get_dependent_components() - if any([c not in self.path[-1] for c in cpids]): - continue + cpids = cpn.get_dependent_components() + if any([c not in self.path[-1] for c in cpids]): + continue ans = cpn.run(self.history, **kwargs) self.path[-1].append(c) ran += 1 diff --git a/agent/component/base.py b/agent/component/base.py index 1f1fd33de..bca2fc0be 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -397,6 +397,10 @@ class ComponentBase(ABC): self._param = param self._param.check() + def get_dependent_components(self): + cpnts = [para["component_id"] for para in self._param.query] + return cpnts + def run(self, history, **kwargs): flow_logger.info("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), json.dumps(kwargs, ensure_ascii=False))) diff --git a/graphrag/search.py b/graphrag/search.py index 85ba0698a..a5574466a 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -68,7 +68,7 @@ class KGSearch(Dealer): s["knn"]["filter"] = bqry.to_dict() q_vec = s["knn"]["query_vector"] - ent_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) + ent_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src) entities = [d["name_kwd"] for d in self.es.getSource(ent_res)] ent_ids = self.es.getDocIds(ent_res) if merge_into_first(ent_res, "-Entities-"): @@ -81,7 +81,7 @@ class KGSearch(Dealer): s = Search() s = s.query(bqry)[0: 32] s = s.to_dict() - comm_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) + comm_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src) comm_ids = self.es.getDocIds(comm_res) if merge_into_first(comm_res, "-Community Report-"): comm_ids = comm_ids[0:1] @@ -92,7 +92,7 @@ class KGSearch(Dealer): s = Search() s = s.query(bqry)[0: 6] s = s.to_dict() - txt_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) + txt_res = self.es.search(deepcopy(s), idxnms=idxnm, timeout="600s", src=src) txt_ids = self.es.getDocIds(txt_res) if merge_into_first(txt_res, "-Original Content-"): txt_ids = txt_ids[0:1]