diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 7897749687..aa32002668 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager @@ -33,7 +33,8 @@ class GraphNode(BaseModel): """sub graph of the node, e.g. iteration/loop sub graph""" def add_child(self, node_id: str) -> None: - self.descendant_node_ids.append(node_id) + if node_id not in self.descendant_node_ids: + self.descendant_node_ids.append(node_id) def get_run_condition_handler(self) -> Optional[RunConditionHandler]: """ @@ -56,6 +57,12 @@ class Graph(BaseModel): root_node: GraphNode """root node of the graph""" + @model_validator(mode='after') + def add_root_node(cls, values): + root_node = values.root_node + values.graph_nodes[root_node.id] = root_node + return values + @classmethod def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph": """ @@ -76,10 +83,7 @@ class Graph(BaseModel): run_condition=run_condition ) - graph = cls(root_node=root_node) - - graph._add_graph_node(graph.root_node) - return graph + return cls(root_node=root_node) def add_edge(self, edge_config: dict, source_node_config: dict, @@ -106,7 +110,10 @@ class Graph(BaseModel): if not target_node_id: return - source_node = self.graph_nodes[source_node_id] + source_node = self.graph_nodes.get(source_node_id) + if not source_node: + return + source_node.add_child(target_node_id) if target_node_id not in self.graph_nodes: @@ -120,45 +127,66 @@ class Graph(BaseModel): sub_graph=target_node_sub_graph ) - self._add_graph_node(target_graph_node) + self.add_graph_node(target_graph_node) else: - target_node = self.graph_nodes[target_node_id] + target_node = self.graph_nodes.get(target_node_id) + if not target_node: + return + target_node.predecessor_node_id = source_node_id target_node.run_condition = run_condition target_node.source_edge_config = edge_config target_node.sub_graph = target_node_sub_graph - def get_root_node(self) -> Optional[GraphNode]: + def get_leaf_nodes(self) -> list[GraphNode]: """ - Get root node of the graph + Get leaf nodes of the graph - :return: root node + :return: leaf nodes """ - return self.root_node + leaf_nodes = [] + for node_id, graph_node in self.graph_nodes.items(): + if ( + not graph_node.descendant_node_ids # has no child + or # or has only one child and the child is the root node + ( + graph_node.descendant_node_ids + and graph_node.descendant_node_ids[0] == self.root_node.id + ) + ): + leaf_nodes.append(graph_node) - def get_descendants_graph(self, node_id: str) -> Optional["Graph"]: + return leaf_nodes + + def get_descendant_graphs(self, node_id: str) -> list["Graph"]: """ - Get descendants graph of the specific node + Get descendant graphs of the specific node :param node_id: node id - :return: descendants graph + :return: descendant graphs """ if node_id not in self.graph_nodes: - return None + return [] - graph_node = self.graph_nodes[node_id] - if not graph_node.descendant_node_ids: - return None + graph_node = self.graph_nodes.get(node_id) + if not graph_node or not graph_node.descendant_node_ids: + return [] - descendants_graph = Graph(root_node=graph_node) - descendants_graph._add_graph_node(graph_node) + descendant_graphs: list[Graph] = [] + for descendant_node_id in graph_node.descendant_node_ids: + descendant_graph_node = self.graph_nodes.get(descendant_node_id) + if not descendant_graph_node: + continue - for child_node_id in graph_node.descendant_node_ids: - self._add_descendants_graph_nodes(descendants_graph, child_node_id) + descendants_graph = Graph(root_node=descendant_graph_node) + for sub_descendant_node_id in descendant_graph_node.descendant_node_ids: + descendants_graph.add_descendants_graph_nodes(self, sub_descendant_node_id) - return descendants_graph + descendant_graphs.append(descendants_graph) - def _add_graph_node(self, graph_node: GraphNode) -> None: + return descendant_graphs + + def add_graph_node(self, graph_node: GraphNode) -> None: """ Add graph node to the graph @@ -167,23 +195,24 @@ class Graph(BaseModel): if graph_node.id in self.graph_nodes: return - if len(self.graph_nodes) == 0: - self.root_node = graph_node - self.graph_nodes[graph_node.id] = graph_node - def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None: + def add_descendants_graph_nodes(self, predecessor_graph: "Graph", node_id: str) -> None: """ Add descendants graph nodes - :param descendants_graph: descendants graph + :param predecessor_graph: predecessor graph :param node_id: node id """ - if node_id not in self.graph_nodes: + if node_id not in predecessor_graph.graph_nodes: return - graph_node = self.graph_nodes[node_id] - descendants_graph._add_graph_node(graph_node) + graph_node = predecessor_graph.graph_nodes.get(node_id) + if not graph_node: + return - for child_node_id in graph_node.descendant_node_ids: - self._add_descendants_graph_nodes(descendants_graph, child_node_id) + if graph_node.id not in self.graph_nodes: + self.add_graph_node(graph_node) + + for child_node_id in graph_node.descendant_node_ids: + self.add_descendants_graph_nodes(predecessor_graph, child_node_id) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 03f20421f1..de59302c18 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -298,6 +298,11 @@ class WorkflowEntry: root_node_configs=root_node_configs ) + # add edge from end node to first node of sub graph + sub_graph_root_node_id = sub_graph.root_node.id + for leaf_node in sub_graph.get_leaf_nodes(): + leaf_node.add_child(sub_graph_root_node_id) + # parse run condition run_condition = None if edge_config.get('sourceHandle'): diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 5b9d0af3cd..47e62aff8c 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,3 +1,6 @@ +from typing import Optional + +from core.workflow.graph_engine.entities.graph import Graph from core.workflow.workflow_entry import WorkflowEntry @@ -230,3 +233,461 @@ def test__init_graph(): assert graph.graph_nodes.get("llm").run_condition is not None assert graph.graph_nodes.get("1719481315734").run_condition is not None + + +def test__init_graph_with_iteration(): + graph_config = { + "edges": [ + { + "data": { + "sourceType": "llm", + "targetType": "answer" + }, + "id": "llm-answer", + "source": "llm", + "sourceHandle": "source", + "target": "answer", + "targetHandle": "target", + "type": "custom" + }, + { + "data": { + "isInIteration": False, + "sourceType": "iteration", + "targetType": "llm" + }, + "id": "1720001776597-source-llm-target", + "selected": False, + "source": "1720001776597", + "sourceHandle": "source", + "target": "llm", + "targetHandle": "target", + "type": "custom", + "zIndex": 0 + }, + { + "data": { + "isInIteration": True, + "iteration_id": "1720001776597", + "sourceType": "template-transform", + "targetType": "llm" + }, + "id": "1720001783092-source-1720001859851-target", + "source": "1720001783092", + "sourceHandle": "source", + "target": "1720001859851", + "targetHandle": "target", + "type": "custom", + "zIndex": 1002 + }, + { + "data": { + "isInIteration": True, + "iteration_id": "1720001776597", + "sourceType": "llm", + "targetType": "answer" + }, + "id": "1720001859851-source-1720001879621-target", + "source": "1720001859851", + "sourceHandle": "source", + "target": "1720001879621", + "targetHandle": "target", + "type": "custom", + "zIndex": 1002 + }, + { + "data": { + "isInIteration": False, + "sourceType": "start", + "targetType": "code" + }, + "id": "1720001771022-source-1720001956578-target", + "source": "1720001771022", + "sourceHandle": "source", + "target": "1720001956578", + "targetHandle": "target", + "type": "custom", + "zIndex": 0 + }, + { + "data": { + "isInIteration": False, + "sourceType": "code", + "targetType": "iteration" + }, + "id": "1720001956578-source-1720001776597-target", + "source": "1720001956578", + "sourceHandle": "source", + "target": "1720001776597", + "targetHandle": "target", + "type": "custom", + "zIndex": 0 + } + ], + "nodes": [ + { + "data": { + "desc": "", + "selected": False, + "title": "Start", + "type": "start", + "variables": [] + }, + "height": 53, + "id": "1720001771022", + "position": { + "x": 80, + "y": 282 + }, + "positionAbsolute": { + "x": 80, + "y": 282 + }, + "selected": False, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244 + }, + { + "data": { + "context": { + "enabled": False, + "variable_selector": [] + }, + "desc": "", + "memory": { + "role_prefix": { + "assistant": "", + "user": "" + }, + "window": { + "enabled": False, + "size": 10 + } + }, + "model": { + "completion_params": { + "temperature": 0.7 + }, + "mode": "chat", + "name": "gpt-3.5-turbo", + "provider": "openai" + }, + "prompt_template": [ + { + "id": "b7d1350e-cf0d-4ff3-8ad0-52b6f1218781", + "role": "system", + "text": "" + } + ], + "selected": False, + "title": "LLM", + "type": "llm", + "variables": [], + "vision": { + "enabled": False + } + }, + "height": 97, + "id": "llm", + "position": { + "x": 1730.595805935594, + "y": 282 + }, + "positionAbsolute": { + "x": 1730.595805935594, + "y": 282 + }, + "selected": True, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244 + }, + { + "data": { + "answer": "{{#llm.text#}}", + "desc": "", + "selected": False, + "title": "Answer", + "type": "answer", + "variables": [] + }, + "height": 105, + "id": "answer", + "position": { + "x": 2042.803154918583, + "y": 282 + }, + "positionAbsolute": { + "x": 2042.803154918583, + "y": 282 + }, + "selected": False, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244 + }, + { + "data": { + "desc": "", + "height": 202, + "iterator_selector": [ + "1720001956578", + "result" + ], + "output_selector": [ + "1720001859851", + "text" + ], + "output_type": "array[string]", + "selected": False, + "startNodeType": "template-transform", + "start_node_id": "1720001783092", + "title": "Iteration", + "type": "iteration", + "width": 985 + }, + "height": 202, + "id": "1720001776597", + "position": { + "x": 678.6748900850307, + "y": 282 + }, + "positionAbsolute": { + "x": 678.6748900850307, + "y": 282 + }, + "selected": False, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 985, + "zIndex": 1 + }, + { + "data": { + "desc": "", + "isInIteration": True, + "isIterationStart": True, + "iteration_id": "1720001776597", + "selected": False, + "template": "{{ arg1 }}", + "title": "Template", + "type": "template-transform", + "variables": [ + { + "value_selector": [ + "1720001776597", + "item" + ], + "variable": "arg1" + } + ] + }, + "extent": "parent", + "height": 53, + "id": "1720001783092", + "parentId": "1720001776597", + "position": { + "x": 117, + "y": 85 + }, + "positionAbsolute": { + "x": 795.6748900850307, + "y": 367 + }, + "selected": False, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244, + "zIndex": 1001 + }, + { + "data": { + "context": { + "enabled": False, + "variable_selector": [] + }, + "desc": "", + "isInIteration": True, + "iteration_id": "1720001776597", + "model": { + "completion_params": { + "temperature": 0.7 + }, + "mode": "chat", + "name": "gpt-3.5-turbo", + "provider": "openai" + }, + "prompt_template": [ + { + "id": "9575b8f2-33c4-4611-b6d0-17d8d436a250", + "role": "system", + "text": "{{#1720001783092.output#}}" + } + ], + "selected": False, + "title": "LLM 2", + "type": "llm", + "variables": [], + "vision": { + "enabled": False + } + }, + "extent": "parent", + "height": 97, + "id": "1720001859851", + "parentId": "1720001776597", + "position": { + "x": 421, + "y": 85 + }, + "positionAbsolute": { + "x": 1099.6748900850307, + "y": 367 + }, + "selected": False, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244, + "zIndex": 1002 + }, + { + "data": { + "answer": "{{#1720001859851.text#}}", + "desc": "", + "isInIteration": True, + "iteration_id": "1720001776597", + "selected": False, + "title": "Answer 2", + "type": "answer", + "variables": [] + }, + "extent": "parent", + "height": 105, + "id": "1720001879621", + "parentId": "1720001776597", + "position": { + "x": 725, + "y": 85 + }, + "positionAbsolute": { + "x": 1403.6748900850307, + "y": 367 + }, + "selected": False, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244, + "zIndex": 1002 + }, + { + "data": { + "code": "\ndef main() -> dict:\n return {\n \"result\": [\n \"a\",\n \"b\"\n ]\n }\n", + "code_language": "python3", + "desc": "", + "outputs": { + "result": { + "children": None, + "type": "array[string]" + } + }, + "selected": False, + "title": "Code", + "type": "code", + "variables": [] + }, + "height": 53, + "id": "1720001956578", + "position": { + "x": 380, + "y": 282 + }, + "positionAbsolute": { + "x": 380, + "y": 282 + }, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244 + } + ] + } + + workflow_entry = WorkflowEntry() + graph = workflow_entry._init_graph( + graph_config=graph_config + ) + + # start 1720001771022 -> code 1720001956578 -> iteration 1720001776597 -> llm llm -> answer answer + # iteration 1720001776597: + # [template 1720001783092 -> llm 1720001859851 -> answer 1720001879621] + + main_graph_orders = [ + "1720001771022", + "1720001956578", + "1720001776597", + "llm", + "answer" + ] + + iteration_sub_graph_orders = [ + "1720001783092", + "1720001859851", + "1720001879621" + ] + + assert graph.root_node.id == "1720001771022" + + print("") + + current_graph = graph + for i, node_id in enumerate(main_graph_orders): + current_root_node = current_graph.root_node + assert current_root_node is not None + assert current_root_node.id == node_id + + if current_root_node.node_config.get("data", {}).get("type") == "iteration": + assert current_root_node.sub_graph is not None + + sub_graph = current_root_node.sub_graph + assert sub_graph.root_node.id == "1720001783092" + + current_sub_graph = sub_graph + for j, sub_node_id in enumerate(iteration_sub_graph_orders): + sub_descendant_graphs = current_sub_graph.get_descendant_graphs(node_id=current_sub_graph.root_node.id) + print(f"Iteration [{current_sub_graph.root_node.id}] -> {len(sub_descendant_graphs)}" + f" {[sub_descendant_graph.root_node.id for sub_descendant_graph in sub_descendant_graphs]}") + + if j == len(iteration_sub_graph_orders) - 1: + break + + assert len(sub_descendant_graphs) == 1 + + first_sub_descendant_graph = sub_descendant_graphs[0] + assert first_sub_descendant_graph.root_node.id == iteration_sub_graph_orders[j + 1] + assert first_sub_descendant_graph.root_node.predecessor_node_id == sub_node_id + + current_sub_graph = first_sub_descendant_graph + + descendant_graphs = current_graph.get_descendant_graphs(node_id=current_graph.root_node.id) + print(f"[{current_graph.root_node.id}] -> {len(descendant_graphs)}" + f" {[descendant_graph.root_node.id for descendant_graph in descendant_graphs]}") + if i == len(main_graph_orders) - 1: + assert len(descendant_graphs) == 0 + break + + assert len(descendant_graphs) == 1 + + first_descendant_graph = descendant_graphs[0] + assert first_descendant_graph.root_node.id == main_graph_orders[i + 1] + assert first_descendant_graph.root_node.predecessor_node_id == node_id + + current_graph = first_descendant_graph