completed graph init test

This commit is contained in:
takatost 2024-07-04 15:40:20 +08:00
parent 0f19b2a986
commit 1b6cd975f3
3 changed files with 531 additions and 36 deletions

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
@ -33,7 +33,8 @@ class GraphNode(BaseModel):
"""sub graph of the node, e.g. iteration/loop sub graph""" """sub graph of the node, e.g. iteration/loop sub graph"""
def add_child(self, node_id: str) -> None: def add_child(self, node_id: str) -> None:
self.descendant_node_ids.append(node_id) if node_id not in self.descendant_node_ids:
self.descendant_node_ids.append(node_id)
def get_run_condition_handler(self) -> Optional[RunConditionHandler]: def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
""" """
@ -56,6 +57,12 @@ class Graph(BaseModel):
root_node: GraphNode root_node: GraphNode
"""root node of the graph""" """root node of the graph"""
@model_validator(mode='after')
def add_root_node(cls, values):
root_node = values.root_node
values.graph_nodes[root_node.id] = root_node
return values
@classmethod @classmethod
def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph": def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph":
""" """
@ -76,10 +83,7 @@ class Graph(BaseModel):
run_condition=run_condition run_condition=run_condition
) )
graph = cls(root_node=root_node) return cls(root_node=root_node)
graph._add_graph_node(graph.root_node)
return graph
def add_edge(self, edge_config: dict, def add_edge(self, edge_config: dict,
source_node_config: dict, source_node_config: dict,
@ -106,7 +110,10 @@ class Graph(BaseModel):
if not target_node_id: if not target_node_id:
return return
source_node = self.graph_nodes[source_node_id] source_node = self.graph_nodes.get(source_node_id)
if not source_node:
return
source_node.add_child(target_node_id) source_node.add_child(target_node_id)
if target_node_id not in self.graph_nodes: if target_node_id not in self.graph_nodes:
@ -120,45 +127,66 @@ class Graph(BaseModel):
sub_graph=target_node_sub_graph sub_graph=target_node_sub_graph
) )
self._add_graph_node(target_graph_node) self.add_graph_node(target_graph_node)
else: else:
target_node = self.graph_nodes[target_node_id] target_node = self.graph_nodes.get(target_node_id)
if not target_node:
return
target_node.predecessor_node_id = source_node_id target_node.predecessor_node_id = source_node_id
target_node.run_condition = run_condition target_node.run_condition = run_condition
target_node.source_edge_config = edge_config target_node.source_edge_config = edge_config
target_node.sub_graph = target_node_sub_graph target_node.sub_graph = target_node_sub_graph
def get_root_node(self) -> Optional[GraphNode]: def get_leaf_nodes(self) -> list[GraphNode]:
""" """
Get root node of the graph Get leaf nodes of the graph
:return: root node :return: leaf nodes
""" """
return self.root_node leaf_nodes = []
for node_id, graph_node in self.graph_nodes.items():
if (
not graph_node.descendant_node_ids # has no child
or # or has only one child and the child is the root node
(
graph_node.descendant_node_ids
and graph_node.descendant_node_ids[0] == self.root_node.id
)
):
leaf_nodes.append(graph_node)
def get_descendants_graph(self, node_id: str) -> Optional["Graph"]: return leaf_nodes
def get_descendant_graphs(self, node_id: str) -> list["Graph"]:
""" """
Get descendants graph of the specific node Get descendant graphs of the specific node
:param node_id: node id :param node_id: node id
:return: descendants graph :return: descendant graphs
""" """
if node_id not in self.graph_nodes: if node_id not in self.graph_nodes:
return None return []
graph_node = self.graph_nodes[node_id] graph_node = self.graph_nodes.get(node_id)
if not graph_node.descendant_node_ids: if not graph_node or not graph_node.descendant_node_ids:
return None return []
descendants_graph = Graph(root_node=graph_node) descendant_graphs: list[Graph] = []
descendants_graph._add_graph_node(graph_node) for descendant_node_id in graph_node.descendant_node_ids:
descendant_graph_node = self.graph_nodes.get(descendant_node_id)
if not descendant_graph_node:
continue
for child_node_id in graph_node.descendant_node_ids: descendants_graph = Graph(root_node=descendant_graph_node)
self._add_descendants_graph_nodes(descendants_graph, child_node_id) for sub_descendant_node_id in descendant_graph_node.descendant_node_ids:
descendants_graph.add_descendants_graph_nodes(self, sub_descendant_node_id)
return descendants_graph descendant_graphs.append(descendants_graph)
def _add_graph_node(self, graph_node: GraphNode) -> None: return descendant_graphs
def add_graph_node(self, graph_node: GraphNode) -> None:
""" """
Add graph node to the graph Add graph node to the graph
@ -167,23 +195,24 @@ class Graph(BaseModel):
if graph_node.id in self.graph_nodes: if graph_node.id in self.graph_nodes:
return return
if len(self.graph_nodes) == 0:
self.root_node = graph_node
self.graph_nodes[graph_node.id] = graph_node self.graph_nodes[graph_node.id] = graph_node
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None: def add_descendants_graph_nodes(self, predecessor_graph: "Graph", node_id: str) -> None:
""" """
Add descendants graph nodes Add descendants graph nodes
:param descendants_graph: descendants graph :param predecessor_graph: predecessor graph
:param node_id: node id :param node_id: node id
""" """
if node_id not in self.graph_nodes: if node_id not in predecessor_graph.graph_nodes:
return return
graph_node = self.graph_nodes[node_id] graph_node = predecessor_graph.graph_nodes.get(node_id)
descendants_graph._add_graph_node(graph_node) if not graph_node:
return
for child_node_id in graph_node.descendant_node_ids: if graph_node.id not in self.graph_nodes:
self._add_descendants_graph_nodes(descendants_graph, child_node_id) self.add_graph_node(graph_node)
for child_node_id in graph_node.descendant_node_ids:
self.add_descendants_graph_nodes(predecessor_graph, child_node_id)

View File

@ -298,6 +298,11 @@ class WorkflowEntry:
root_node_configs=root_node_configs root_node_configs=root_node_configs
) )
# add edge from end node to first node of sub graph
sub_graph_root_node_id = sub_graph.root_node.id
for leaf_node in sub_graph.get_leaf_nodes():
leaf_node.add_child(sub_graph_root_node_id)
# parse run condition # parse run condition
run_condition = None run_condition = None
if edge_config.get('sourceHandle'): if edge_config.get('sourceHandle'):

View File

@ -1,3 +1,6 @@
from typing import Optional
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
@ -230,3 +233,461 @@ def test__init_graph():
assert graph.graph_nodes.get("llm").run_condition is not None assert graph.graph_nodes.get("llm").run_condition is not None
assert graph.graph_nodes.get("1719481315734").run_condition is not None assert graph.graph_nodes.get("1719481315734").run_condition is not None
def test__init_graph_with_iteration():
graph_config = {
"edges": [
{
"data": {
"sourceType": "llm",
"targetType": "answer"
},
"id": "llm-answer",
"source": "llm",
"sourceHandle": "source",
"target": "answer",
"targetHandle": "target",
"type": "custom"
},
{
"data": {
"isInIteration": False,
"sourceType": "iteration",
"targetType": "llm"
},
"id": "1720001776597-source-llm-target",
"selected": False,
"source": "1720001776597",
"sourceHandle": "source",
"target": "llm",
"targetHandle": "target",
"type": "custom",
"zIndex": 0
},
{
"data": {
"isInIteration": True,
"iteration_id": "1720001776597",
"sourceType": "template-transform",
"targetType": "llm"
},
"id": "1720001783092-source-1720001859851-target",
"source": "1720001783092",
"sourceHandle": "source",
"target": "1720001859851",
"targetHandle": "target",
"type": "custom",
"zIndex": 1002
},
{
"data": {
"isInIteration": True,
"iteration_id": "1720001776597",
"sourceType": "llm",
"targetType": "answer"
},
"id": "1720001859851-source-1720001879621-target",
"source": "1720001859851",
"sourceHandle": "source",
"target": "1720001879621",
"targetHandle": "target",
"type": "custom",
"zIndex": 1002
},
{
"data": {
"isInIteration": False,
"sourceType": "start",
"targetType": "code"
},
"id": "1720001771022-source-1720001956578-target",
"source": "1720001771022",
"sourceHandle": "source",
"target": "1720001956578",
"targetHandle": "target",
"type": "custom",
"zIndex": 0
},
{
"data": {
"isInIteration": False,
"sourceType": "code",
"targetType": "iteration"
},
"id": "1720001956578-source-1720001776597-target",
"source": "1720001956578",
"sourceHandle": "source",
"target": "1720001776597",
"targetHandle": "target",
"type": "custom",
"zIndex": 0
}
],
"nodes": [
{
"data": {
"desc": "",
"selected": False,
"title": "Start",
"type": "start",
"variables": []
},
"height": 53,
"id": "1720001771022",
"position": {
"x": 80,
"y": 282
},
"positionAbsolute": {
"x": 80,
"y": 282
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244
},
{
"data": {
"context": {
"enabled": False,
"variable_selector": []
},
"desc": "",
"memory": {
"role_prefix": {
"assistant": "",
"user": ""
},
"window": {
"enabled": False,
"size": 10
}
},
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-3.5-turbo",
"provider": "openai"
},
"prompt_template": [
{
"id": "b7d1350e-cf0d-4ff3-8ad0-52b6f1218781",
"role": "system",
"text": ""
}
],
"selected": False,
"title": "LLM",
"type": "llm",
"variables": [],
"vision": {
"enabled": False
}
},
"height": 97,
"id": "llm",
"position": {
"x": 1730.595805935594,
"y": 282
},
"positionAbsolute": {
"x": 1730.595805935594,
"y": 282
},
"selected": True,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244
},
{
"data": {
"answer": "{{#llm.text#}}",
"desc": "",
"selected": False,
"title": "Answer",
"type": "answer",
"variables": []
},
"height": 105,
"id": "answer",
"position": {
"x": 2042.803154918583,
"y": 282
},
"positionAbsolute": {
"x": 2042.803154918583,
"y": 282
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244
},
{
"data": {
"desc": "",
"height": 202,
"iterator_selector": [
"1720001956578",
"result"
],
"output_selector": [
"1720001859851",
"text"
],
"output_type": "array[string]",
"selected": False,
"startNodeType": "template-transform",
"start_node_id": "1720001783092",
"title": "Iteration",
"type": "iteration",
"width": 985
},
"height": 202,
"id": "1720001776597",
"position": {
"x": 678.6748900850307,
"y": 282
},
"positionAbsolute": {
"x": 678.6748900850307,
"y": 282
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 985,
"zIndex": 1
},
{
"data": {
"desc": "",
"isInIteration": True,
"isIterationStart": True,
"iteration_id": "1720001776597",
"selected": False,
"template": "{{ arg1 }}",
"title": "Template",
"type": "template-transform",
"variables": [
{
"value_selector": [
"1720001776597",
"item"
],
"variable": "arg1"
}
]
},
"extent": "parent",
"height": 53,
"id": "1720001783092",
"parentId": "1720001776597",
"position": {
"x": 117,
"y": 85
},
"positionAbsolute": {
"x": 795.6748900850307,
"y": 367
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
"zIndex": 1001
},
{
"data": {
"context": {
"enabled": False,
"variable_selector": []
},
"desc": "",
"isInIteration": True,
"iteration_id": "1720001776597",
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-3.5-turbo",
"provider": "openai"
},
"prompt_template": [
{
"id": "9575b8f2-33c4-4611-b6d0-17d8d436a250",
"role": "system",
"text": "{{#1720001783092.output#}}"
}
],
"selected": False,
"title": "LLM 2",
"type": "llm",
"variables": [],
"vision": {
"enabled": False
}
},
"extent": "parent",
"height": 97,
"id": "1720001859851",
"parentId": "1720001776597",
"position": {
"x": 421,
"y": 85
},
"positionAbsolute": {
"x": 1099.6748900850307,
"y": 367
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
"zIndex": 1002
},
{
"data": {
"answer": "{{#1720001859851.text#}}",
"desc": "",
"isInIteration": True,
"iteration_id": "1720001776597",
"selected": False,
"title": "Answer 2",
"type": "answer",
"variables": []
},
"extent": "parent",
"height": 105,
"id": "1720001879621",
"parentId": "1720001776597",
"position": {
"x": 725,
"y": 85
},
"positionAbsolute": {
"x": 1403.6748900850307,
"y": 367
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
"zIndex": 1002
},
{
"data": {
"code": "\ndef main() -> dict:\n return {\n \"result\": [\n \"a\",\n \"b\"\n ]\n }\n",
"code_language": "python3",
"desc": "",
"outputs": {
"result": {
"children": None,
"type": "array[string]"
}
},
"selected": False,
"title": "Code",
"type": "code",
"variables": []
},
"height": 53,
"id": "1720001956578",
"position": {
"x": 380,
"y": 282
},
"positionAbsolute": {
"x": 380,
"y": 282
},
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244
}
]
}
workflow_entry = WorkflowEntry()
graph = workflow_entry._init_graph(
graph_config=graph_config
)
# start 1720001771022 -> code 1720001956578 -> iteration 1720001776597 -> llm llm -> answer answer
# iteration 1720001776597:
# [template 1720001783092 -> llm 1720001859851 -> answer 1720001879621]
main_graph_orders = [
"1720001771022",
"1720001956578",
"1720001776597",
"llm",
"answer"
]
iteration_sub_graph_orders = [
"1720001783092",
"1720001859851",
"1720001879621"
]
assert graph.root_node.id == "1720001771022"
print("")
current_graph = graph
for i, node_id in enumerate(main_graph_orders):
current_root_node = current_graph.root_node
assert current_root_node is not None
assert current_root_node.id == node_id
if current_root_node.node_config.get("data", {}).get("type") == "iteration":
assert current_root_node.sub_graph is not None
sub_graph = current_root_node.sub_graph
assert sub_graph.root_node.id == "1720001783092"
current_sub_graph = sub_graph
for j, sub_node_id in enumerate(iteration_sub_graph_orders):
sub_descendant_graphs = current_sub_graph.get_descendant_graphs(node_id=current_sub_graph.root_node.id)
print(f"Iteration [{current_sub_graph.root_node.id}] -> {len(sub_descendant_graphs)}"
f" {[sub_descendant_graph.root_node.id for sub_descendant_graph in sub_descendant_graphs]}")
if j == len(iteration_sub_graph_orders) - 1:
break
assert len(sub_descendant_graphs) == 1
first_sub_descendant_graph = sub_descendant_graphs[0]
assert first_sub_descendant_graph.root_node.id == iteration_sub_graph_orders[j + 1]
assert first_sub_descendant_graph.root_node.predecessor_node_id == sub_node_id
current_sub_graph = first_sub_descendant_graph
descendant_graphs = current_graph.get_descendant_graphs(node_id=current_graph.root_node.id)
print(f"[{current_graph.root_node.id}] -> {len(descendant_graphs)}"
f" {[descendant_graph.root_node.id for descendant_graph in descendant_graphs]}")
if i == len(main_graph_orders) - 1:
assert len(descendant_graphs) == 0
break
assert len(descendant_graphs) == 1
first_descendant_graph = descendant_graphs[0]
assert first_descendant_graph.root_node.id == main_graph_orders[i + 1]
assert first_descendant_graph.root_node.predecessor_node_id == node_id
current_graph = first_descendant_graph