diff --git a/graph/component/base.py b/graph/component/base.py index f0f6367f8..b424d226f 100644 --- a/graph/component/base.py +++ b/graph/component/base.py @@ -445,6 +445,11 @@ class ComponentBase(ABC): if DEBUG: print(self.component_name, reversed_cpnts[::-1]) for u in reversed_cpnts[::-1]: if self.get_component_name(u) in ["switch"]: continue + if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval": + o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1] + if o is not None: + upstream_outs.append(o) + continue if u not in self._canvas.get_component(self._id)["upstream"]: continue if self.component_name.lower().find("switch") < 0 \ and self.get_component_name(u) in ["relevant", "categorize"]: diff --git a/graph/component/generate.py b/graph/component/generate.py index 693bdba7a..83ade128d 100644 --- a/graph/component/generate.py +++ b/graph/component/generate.py @@ -72,7 +72,7 @@ class Generate(ComponentBase): prompt = self._param.prompt retrieval_res = self.get_input() - input = "\n- ".join(retrieval_res["content"]) + input = "\n- ".join(retrieval_res["content"]) if "content" in retrieval_res else "" for para in self._param.parameters: cpn = self._canvas.get_component(para["component_id"])["obj"] _, out = cpn.output(allow_partial=False)