mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 02:35:56 +08:00
completed graph init test
This commit is contained in:
parent
0f19b2a986
commit
1b6cd975f3
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
@ -33,7 +33,8 @@ class GraphNode(BaseModel):
|
||||
"""sub graph of the node, e.g. iteration/loop sub graph"""
|
||||
|
||||
def add_child(self, node_id: str) -> None:
|
||||
self.descendant_node_ids.append(node_id)
|
||||
if node_id not in self.descendant_node_ids:
|
||||
self.descendant_node_ids.append(node_id)
|
||||
|
||||
def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
|
||||
"""
|
||||
@ -56,6 +57,12 @@ class Graph(BaseModel):
|
||||
root_node: GraphNode
|
||||
"""root node of the graph"""
|
||||
|
||||
@model_validator(mode='after')
|
||||
def add_root_node(cls, values):
|
||||
root_node = values.root_node
|
||||
values.graph_nodes[root_node.id] = root_node
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph":
|
||||
"""
|
||||
@ -76,10 +83,7 @@ class Graph(BaseModel):
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
graph = cls(root_node=root_node)
|
||||
|
||||
graph._add_graph_node(graph.root_node)
|
||||
return graph
|
||||
return cls(root_node=root_node)
|
||||
|
||||
def add_edge(self, edge_config: dict,
|
||||
source_node_config: dict,
|
||||
@ -106,7 +110,10 @@ class Graph(BaseModel):
|
||||
if not target_node_id:
|
||||
return
|
||||
|
||||
source_node = self.graph_nodes[source_node_id]
|
||||
source_node = self.graph_nodes.get(source_node_id)
|
||||
if not source_node:
|
||||
return
|
||||
|
||||
source_node.add_child(target_node_id)
|
||||
|
||||
if target_node_id not in self.graph_nodes:
|
||||
@ -120,45 +127,66 @@ class Graph(BaseModel):
|
||||
sub_graph=target_node_sub_graph
|
||||
)
|
||||
|
||||
self._add_graph_node(target_graph_node)
|
||||
self.add_graph_node(target_graph_node)
|
||||
else:
|
||||
target_node = self.graph_nodes[target_node_id]
|
||||
target_node = self.graph_nodes.get(target_node_id)
|
||||
if not target_node:
|
||||
return
|
||||
|
||||
target_node.predecessor_node_id = source_node_id
|
||||
target_node.run_condition = run_condition
|
||||
target_node.source_edge_config = edge_config
|
||||
target_node.sub_graph = target_node_sub_graph
|
||||
|
||||
def get_root_node(self) -> Optional[GraphNode]:
|
||||
def get_leaf_nodes(self) -> list[GraphNode]:
|
||||
"""
|
||||
Get root node of the graph
|
||||
Get leaf nodes of the graph
|
||||
|
||||
:return: root node
|
||||
:return: leaf nodes
|
||||
"""
|
||||
return self.root_node
|
||||
leaf_nodes = []
|
||||
for node_id, graph_node in self.graph_nodes.items():
|
||||
if (
|
||||
not graph_node.descendant_node_ids # has no child
|
||||
or # or has only one child and the child is the root node
|
||||
(
|
||||
graph_node.descendant_node_ids
|
||||
and graph_node.descendant_node_ids[0] == self.root_node.id
|
||||
)
|
||||
):
|
||||
leaf_nodes.append(graph_node)
|
||||
|
||||
def get_descendants_graph(self, node_id: str) -> Optional["Graph"]:
|
||||
return leaf_nodes
|
||||
|
||||
def get_descendant_graphs(self, node_id: str) -> list["Graph"]:
|
||||
"""
|
||||
Get descendants graph of the specific node
|
||||
Get descendant graphs of the specific node
|
||||
|
||||
:param node_id: node id
|
||||
:return: descendants graph
|
||||
:return: descendant graphs
|
||||
"""
|
||||
if node_id not in self.graph_nodes:
|
||||
return None
|
||||
return []
|
||||
|
||||
graph_node = self.graph_nodes[node_id]
|
||||
if not graph_node.descendant_node_ids:
|
||||
return None
|
||||
graph_node = self.graph_nodes.get(node_id)
|
||||
if not graph_node or not graph_node.descendant_node_ids:
|
||||
return []
|
||||
|
||||
descendants_graph = Graph(root_node=graph_node)
|
||||
descendants_graph._add_graph_node(graph_node)
|
||||
descendant_graphs: list[Graph] = []
|
||||
for descendant_node_id in graph_node.descendant_node_ids:
|
||||
descendant_graph_node = self.graph_nodes.get(descendant_node_id)
|
||||
if not descendant_graph_node:
|
||||
continue
|
||||
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
||||
descendants_graph = Graph(root_node=descendant_graph_node)
|
||||
for sub_descendant_node_id in descendant_graph_node.descendant_node_ids:
|
||||
descendants_graph.add_descendants_graph_nodes(self, sub_descendant_node_id)
|
||||
|
||||
return descendants_graph
|
||||
descendant_graphs.append(descendants_graph)
|
||||
|
||||
def _add_graph_node(self, graph_node: GraphNode) -> None:
|
||||
return descendant_graphs
|
||||
|
||||
def add_graph_node(self, graph_node: GraphNode) -> None:
|
||||
"""
|
||||
Add graph node to the graph
|
||||
|
||||
@ -167,23 +195,24 @@ class Graph(BaseModel):
|
||||
if graph_node.id in self.graph_nodes:
|
||||
return
|
||||
|
||||
if len(self.graph_nodes) == 0:
|
||||
self.root_node = graph_node
|
||||
|
||||
self.graph_nodes[graph_node.id] = graph_node
|
||||
|
||||
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
|
||||
def add_descendants_graph_nodes(self, predecessor_graph: "Graph", node_id: str) -> None:
|
||||
"""
|
||||
Add descendants graph nodes
|
||||
|
||||
:param descendants_graph: descendants graph
|
||||
:param predecessor_graph: predecessor graph
|
||||
:param node_id: node id
|
||||
"""
|
||||
if node_id not in self.graph_nodes:
|
||||
if node_id not in predecessor_graph.graph_nodes:
|
||||
return
|
||||
|
||||
graph_node = self.graph_nodes[node_id]
|
||||
descendants_graph._add_graph_node(graph_node)
|
||||
graph_node = predecessor_graph.graph_nodes.get(node_id)
|
||||
if not graph_node:
|
||||
return
|
||||
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
||||
if graph_node.id not in self.graph_nodes:
|
||||
self.add_graph_node(graph_node)
|
||||
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self.add_descendants_graph_nodes(predecessor_graph, child_node_id)
|
||||
|
@ -298,6 +298,11 @@ class WorkflowEntry:
|
||||
root_node_configs=root_node_configs
|
||||
)
|
||||
|
||||
# add edge from end node to first node of sub graph
|
||||
sub_graph_root_node_id = sub_graph.root_node.id
|
||||
for leaf_node in sub_graph.get_leaf_nodes():
|
||||
leaf_node.add_child(sub_graph_root_node_id)
|
||||
|
||||
# parse run condition
|
||||
run_condition = None
|
||||
if edge_config.get('sourceHandle'):
|
||||
|
@ -1,3 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
||||
|
||||
@ -230,3 +233,461 @@ def test__init_graph():
|
||||
|
||||
assert graph.graph_nodes.get("llm").run_condition is not None
|
||||
assert graph.graph_nodes.get("1719481315734").run_condition is not None
|
||||
|
||||
|
||||
def test__init_graph_with_iteration():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"data": {
|
||||
"sourceType": "llm",
|
||||
"targetType": "answer"
|
||||
},
|
||||
"id": "llm-answer",
|
||||
"source": "llm",
|
||||
"sourceHandle": "source",
|
||||
"target": "answer",
|
||||
"targetHandle": "target",
|
||||
"type": "custom"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"isInIteration": False,
|
||||
"sourceType": "iteration",
|
||||
"targetType": "llm"
|
||||
},
|
||||
"id": "1720001776597-source-llm-target",
|
||||
"selected": False,
|
||||
"source": "1720001776597",
|
||||
"sourceHandle": "source",
|
||||
"target": "llm",
|
||||
"targetHandle": "target",
|
||||
"type": "custom",
|
||||
"zIndex": 0
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"isInIteration": True,
|
||||
"iteration_id": "1720001776597",
|
||||
"sourceType": "template-transform",
|
||||
"targetType": "llm"
|
||||
},
|
||||
"id": "1720001783092-source-1720001859851-target",
|
||||
"source": "1720001783092",
|
||||
"sourceHandle": "source",
|
||||
"target": "1720001859851",
|
||||
"targetHandle": "target",
|
||||
"type": "custom",
|
||||
"zIndex": 1002
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"isInIteration": True,
|
||||
"iteration_id": "1720001776597",
|
||||
"sourceType": "llm",
|
||||
"targetType": "answer"
|
||||
},
|
||||
"id": "1720001859851-source-1720001879621-target",
|
||||
"source": "1720001859851",
|
||||
"sourceHandle": "source",
|
||||
"target": "1720001879621",
|
||||
"targetHandle": "target",
|
||||
"type": "custom",
|
||||
"zIndex": 1002
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"isInIteration": False,
|
||||
"sourceType": "start",
|
||||
"targetType": "code"
|
||||
},
|
||||
"id": "1720001771022-source-1720001956578-target",
|
||||
"source": "1720001771022",
|
||||
"sourceHandle": "source",
|
||||
"target": "1720001956578",
|
||||
"targetHandle": "target",
|
||||
"type": "custom",
|
||||
"zIndex": 0
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"isInIteration": False,
|
||||
"sourceType": "code",
|
||||
"targetType": "iteration"
|
||||
},
|
||||
"id": "1720001956578-source-1720001776597-target",
|
||||
"source": "1720001956578",
|
||||
"sourceHandle": "source",
|
||||
"target": "1720001776597",
|
||||
"targetHandle": "target",
|
||||
"type": "custom",
|
||||
"zIndex": 0
|
||||
}
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"desc": "",
|
||||
"selected": False,
|
||||
"title": "Start",
|
||||
"type": "start",
|
||||
"variables": []
|
||||
},
|
||||
"height": 53,
|
||||
"id": "1720001771022",
|
||||
"position": {
|
||||
"x": 80,
|
||||
"y": 282
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 80,
|
||||
"y": 282
|
||||
},
|
||||
"selected": False,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"context": {
|
||||
"enabled": False,
|
||||
"variable_selector": []
|
||||
},
|
||||
"desc": "",
|
||||
"memory": {
|
||||
"role_prefix": {
|
||||
"assistant": "",
|
||||
"user": ""
|
||||
},
|
||||
"window": {
|
||||
"enabled": False,
|
||||
"size": 10
|
||||
}
|
||||
},
|
||||
"model": {
|
||||
"completion_params": {
|
||||
"temperature": 0.7
|
||||
},
|
||||
"mode": "chat",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"provider": "openai"
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"id": "b7d1350e-cf0d-4ff3-8ad0-52b6f1218781",
|
||||
"role": "system",
|
||||
"text": ""
|
||||
}
|
||||
],
|
||||
"selected": False,
|
||||
"title": "LLM",
|
||||
"type": "llm",
|
||||
"variables": [],
|
||||
"vision": {
|
||||
"enabled": False
|
||||
}
|
||||
},
|
||||
"height": 97,
|
||||
"id": "llm",
|
||||
"position": {
|
||||
"x": 1730.595805935594,
|
||||
"y": 282
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 1730.595805935594,
|
||||
"y": 282
|
||||
},
|
||||
"selected": True,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#llm.text#}}",
|
||||
"desc": "",
|
||||
"selected": False,
|
||||
"title": "Answer",
|
||||
"type": "answer",
|
||||
"variables": []
|
||||
},
|
||||
"height": 105,
|
||||
"id": "answer",
|
||||
"position": {
|
||||
"x": 2042.803154918583,
|
||||
"y": 282
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 2042.803154918583,
|
||||
"y": 282
|
||||
},
|
||||
"selected": False,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"desc": "",
|
||||
"height": 202,
|
||||
"iterator_selector": [
|
||||
"1720001956578",
|
||||
"result"
|
||||
],
|
||||
"output_selector": [
|
||||
"1720001859851",
|
||||
"text"
|
||||
],
|
||||
"output_type": "array[string]",
|
||||
"selected": False,
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "1720001783092",
|
||||
"title": "Iteration",
|
||||
"type": "iteration",
|
||||
"width": 985
|
||||
},
|
||||
"height": 202,
|
||||
"id": "1720001776597",
|
||||
"position": {
|
||||
"x": 678.6748900850307,
|
||||
"y": 282
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 678.6748900850307,
|
||||
"y": 282
|
||||
},
|
||||
"selected": False,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 985,
|
||||
"zIndex": 1
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"desc": "",
|
||||
"isInIteration": True,
|
||||
"isIterationStart": True,
|
||||
"iteration_id": "1720001776597",
|
||||
"selected": False,
|
||||
"template": "{{ arg1 }}",
|
||||
"title": "Template",
|
||||
"type": "template-transform",
|
||||
"variables": [
|
||||
{
|
||||
"value_selector": [
|
||||
"1720001776597",
|
||||
"item"
|
||||
],
|
||||
"variable": "arg1"
|
||||
}
|
||||
]
|
||||
},
|
||||
"extent": "parent",
|
||||
"height": 53,
|
||||
"id": "1720001783092",
|
||||
"parentId": "1720001776597",
|
||||
"position": {
|
||||
"x": 117,
|
||||
"y": 85
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 795.6748900850307,
|
||||
"y": 367
|
||||
},
|
||||
"selected": False,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244,
|
||||
"zIndex": 1001
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"context": {
|
||||
"enabled": False,
|
||||
"variable_selector": []
|
||||
},
|
||||
"desc": "",
|
||||
"isInIteration": True,
|
||||
"iteration_id": "1720001776597",
|
||||
"model": {
|
||||
"completion_params": {
|
||||
"temperature": 0.7
|
||||
},
|
||||
"mode": "chat",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"provider": "openai"
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"id": "9575b8f2-33c4-4611-b6d0-17d8d436a250",
|
||||
"role": "system",
|
||||
"text": "{{#1720001783092.output#}}"
|
||||
}
|
||||
],
|
||||
"selected": False,
|
||||
"title": "LLM 2",
|
||||
"type": "llm",
|
||||
"variables": [],
|
||||
"vision": {
|
||||
"enabled": False
|
||||
}
|
||||
},
|
||||
"extent": "parent",
|
||||
"height": 97,
|
||||
"id": "1720001859851",
|
||||
"parentId": "1720001776597",
|
||||
"position": {
|
||||
"x": 421,
|
||||
"y": 85
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 1099.6748900850307,
|
||||
"y": 367
|
||||
},
|
||||
"selected": False,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244,
|
||||
"zIndex": 1002
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#1720001859851.text#}}",
|
||||
"desc": "",
|
||||
"isInIteration": True,
|
||||
"iteration_id": "1720001776597",
|
||||
"selected": False,
|
||||
"title": "Answer 2",
|
||||
"type": "answer",
|
||||
"variables": []
|
||||
},
|
||||
"extent": "parent",
|
||||
"height": 105,
|
||||
"id": "1720001879621",
|
||||
"parentId": "1720001776597",
|
||||
"position": {
|
||||
"x": 725,
|
||||
"y": 85
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 1403.6748900850307,
|
||||
"y": 367
|
||||
},
|
||||
"selected": False,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244,
|
||||
"zIndex": 1002
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"code": "\ndef main() -> dict:\n return {\n \"result\": [\n \"a\",\n \"b\"\n ]\n }\n",
|
||||
"code_language": "python3",
|
||||
"desc": "",
|
||||
"outputs": {
|
||||
"result": {
|
||||
"children": None,
|
||||
"type": "array[string]"
|
||||
}
|
||||
},
|
||||
"selected": False,
|
||||
"title": "Code",
|
||||
"type": "code",
|
||||
"variables": []
|
||||
},
|
||||
"height": 53,
|
||||
"id": "1720001956578",
|
||||
"position": {
|
||||
"x": 380,
|
||||
"y": 282
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 380,
|
||||
"y": 282
|
||||
},
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "custom",
|
||||
"width": 244
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
workflow_entry = WorkflowEntry()
|
||||
graph = workflow_entry._init_graph(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
# start 1720001771022 -> code 1720001956578 -> iteration 1720001776597 -> llm llm -> answer answer
|
||||
# iteration 1720001776597:
|
||||
# [template 1720001783092 -> llm 1720001859851 -> answer 1720001879621]
|
||||
|
||||
main_graph_orders = [
|
||||
"1720001771022",
|
||||
"1720001956578",
|
||||
"1720001776597",
|
||||
"llm",
|
||||
"answer"
|
||||
]
|
||||
|
||||
iteration_sub_graph_orders = [
|
||||
"1720001783092",
|
||||
"1720001859851",
|
||||
"1720001879621"
|
||||
]
|
||||
|
||||
assert graph.root_node.id == "1720001771022"
|
||||
|
||||
print("")
|
||||
|
||||
current_graph = graph
|
||||
for i, node_id in enumerate(main_graph_orders):
|
||||
current_root_node = current_graph.root_node
|
||||
assert current_root_node is not None
|
||||
assert current_root_node.id == node_id
|
||||
|
||||
if current_root_node.node_config.get("data", {}).get("type") == "iteration":
|
||||
assert current_root_node.sub_graph is not None
|
||||
|
||||
sub_graph = current_root_node.sub_graph
|
||||
assert sub_graph.root_node.id == "1720001783092"
|
||||
|
||||
current_sub_graph = sub_graph
|
||||
for j, sub_node_id in enumerate(iteration_sub_graph_orders):
|
||||
sub_descendant_graphs = current_sub_graph.get_descendant_graphs(node_id=current_sub_graph.root_node.id)
|
||||
print(f"Iteration [{current_sub_graph.root_node.id}] -> {len(sub_descendant_graphs)}"
|
||||
f" {[sub_descendant_graph.root_node.id for sub_descendant_graph in sub_descendant_graphs]}")
|
||||
|
||||
if j == len(iteration_sub_graph_orders) - 1:
|
||||
break
|
||||
|
||||
assert len(sub_descendant_graphs) == 1
|
||||
|
||||
first_sub_descendant_graph = sub_descendant_graphs[0]
|
||||
assert first_sub_descendant_graph.root_node.id == iteration_sub_graph_orders[j + 1]
|
||||
assert first_sub_descendant_graph.root_node.predecessor_node_id == sub_node_id
|
||||
|
||||
current_sub_graph = first_sub_descendant_graph
|
||||
|
||||
descendant_graphs = current_graph.get_descendant_graphs(node_id=current_graph.root_node.id)
|
||||
print(f"[{current_graph.root_node.id}] -> {len(descendant_graphs)}"
|
||||
f" {[descendant_graph.root_node.id for descendant_graph in descendant_graphs]}")
|
||||
if i == len(main_graph_orders) - 1:
|
||||
assert len(descendant_graphs) == 0
|
||||
break
|
||||
|
||||
assert len(descendant_graphs) == 1
|
||||
|
||||
first_descendant_graph = descendant_graphs[0]
|
||||
assert first_descendant_graph.root_node.id == main_graph_orders[i + 1]
|
||||
assert first_descendant_graph.root_node.predecessor_node_id == node_id
|
||||
|
||||
current_graph = first_descendant_graph
|
||||
|
Loading…
x
Reference in New Issue
Block a user