mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 19:58:59 +08:00
fix(workflow): bugs
This commit is contained in:
parent
43240fcd41
commit
5bda3a384a
@ -309,12 +309,15 @@ class Graph(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _recursively_add_parallels(cls,
|
def _recursively_add_parallels(
|
||||||
edge_mapping: dict[str, list[GraphEdge]],
|
cls,
|
||||||
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
edge_mapping: dict[str, list[GraphEdge]],
|
||||||
start_node_id: str,
|
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||||
parallel_mapping: dict[str, GraphParallel],
|
start_node_id: str,
|
||||||
node_parallel_mapping: dict[str, str]) -> None:
|
parallel_mapping: dict[str, GraphParallel],
|
||||||
|
node_parallel_mapping: dict[str, str],
|
||||||
|
parent_parallel: Optional[GraphParallel] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Recursively add parallel ids
|
Recursively add parallel ids
|
||||||
|
|
||||||
@ -322,8 +325,10 @@ class Graph(BaseModel):
|
|||||||
:param start_node_id: start from node id
|
:param start_node_id: start from node id
|
||||||
:param parallel_mapping: parallel mapping
|
:param parallel_mapping: parallel mapping
|
||||||
:param node_parallel_mapping: node parallel mapping
|
:param node_parallel_mapping: node parallel mapping
|
||||||
|
:param parent_parallel: parent parallel
|
||||||
"""
|
"""
|
||||||
target_node_edges = edge_mapping.get(start_node_id, [])
|
target_node_edges = edge_mapping.get(start_node_id, [])
|
||||||
|
parallel = None
|
||||||
if len(target_node_edges) > 1:
|
if len(target_node_edges) > 1:
|
||||||
# fetch all node ids in current parallels
|
# fetch all node ids in current parallels
|
||||||
parallel_branch_node_ids = []
|
parallel_branch_node_ids = []
|
||||||
@ -345,18 +350,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# any target node id in node_parallel_mapping
|
# any target node id in node_parallel_mapping
|
||||||
if parallel_branch_node_ids:
|
if parallel_branch_node_ids:
|
||||||
# all parallel_branch_node_ids in node_parallel_mapping
|
parent_parallel_id = parent_parallel.id if parent_parallel else None
|
||||||
parent_parallel_id = None
|
|
||||||
for node_id in parallel_branch_node_ids:
|
|
||||||
if node_id in node_parallel_mapping:
|
|
||||||
parent_parallel_id = node_parallel_mapping[node_id]
|
|
||||||
break
|
|
||||||
|
|
||||||
parent_parallel = None
|
|
||||||
if parent_parallel_id:
|
|
||||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
|
||||||
if not parent_parallel:
|
|
||||||
raise Exception(f"Parent parallel {parent_parallel_id} not found")
|
|
||||||
|
|
||||||
parallel = GraphParallel(
|
parallel = GraphParallel(
|
||||||
start_from_node_id=start_node_id,
|
start_from_node_id=start_node_id,
|
||||||
@ -375,8 +369,17 @@ class Graph(BaseModel):
|
|||||||
parallel_node_ids = []
|
parallel_node_ids = []
|
||||||
for _, node_ids in in_branch_node_ids.items():
|
for _, node_ids in in_branch_node_ids.items():
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
parallel_node_ids.append(node_id)
|
in_parent_parallel = True
|
||||||
node_parallel_mapping[node_id] = parallel.id
|
if parent_parallel_id:
|
||||||
|
in_parent_parallel = False
|
||||||
|
for parallel_node_id, parallel_id in node_parallel_mapping.items():
|
||||||
|
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
|
||||||
|
in_parent_parallel = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if in_parent_parallel:
|
||||||
|
parallel_node_ids.append(node_id)
|
||||||
|
node_parallel_mapping[node_id] = parallel.id
|
||||||
|
|
||||||
outside_parallel_target_node_ids = set()
|
outside_parallel_target_node_ids = set()
|
||||||
for node_id in parallel_node_ids:
|
for node_id in parallel_node_ids:
|
||||||
@ -418,7 +421,8 @@ class Graph(BaseModel):
|
|||||||
reverse_edge_mapping=reverse_edge_mapping,
|
reverse_edge_mapping=reverse_edge_mapping,
|
||||||
start_node_id=graph_edge.target_node_id,
|
start_node_id=graph_edge.target_node_id,
|
||||||
parallel_mapping=parallel_mapping,
|
parallel_mapping=parallel_mapping,
|
||||||
node_parallel_mapping=node_parallel_mapping
|
node_parallel_mapping=node_parallel_mapping,
|
||||||
|
parent_parallel=parallel if parallel else parent_parallel
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -538,14 +542,14 @@ class Graph(BaseModel):
|
|||||||
edge_mapping=edge_mapping
|
edge_mapping=edge_mapping
|
||||||
):
|
):
|
||||||
if node_id in merge_branch_node_ids:
|
if node_id in merge_branch_node_ids:
|
||||||
del merge_branch_node_ids[node_id]
|
del merge_branch_node_ids[node_id2]
|
||||||
elif cls._is_node2_after_node1(
|
elif cls._is_node2_after_node1(
|
||||||
node1_id=node_id2,
|
node1_id=node_id2,
|
||||||
node2_id=node_id,
|
node2_id=node_id,
|
||||||
edge_mapping=edge_mapping
|
edge_mapping=edge_mapping
|
||||||
):
|
):
|
||||||
if node_id2 in merge_branch_node_ids:
|
if node_id2 in merge_branch_node_ids:
|
||||||
del merge_branch_node_ids[node_id2]
|
del merge_branch_node_ids[node_id]
|
||||||
|
|
||||||
branches_merge_node_ids: dict[str, str] = {}
|
branches_merge_node_ids: dict[str, str] = {}
|
||||||
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
||||||
@ -613,13 +617,30 @@ class Graph(BaseModel):
|
|||||||
if start_node_id not in reverse_edge_mapping:
|
if start_node_id not in reverse_edge_mapping:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
all_routes_node_ids = []
|
all_routes_node_ids = set()
|
||||||
for _, node_ids in routes_node_ids.items():
|
parallel_start_node_ids: dict[str, list[str]] = {}
|
||||||
|
for branch_node_id, node_ids in routes_node_ids.items():
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
all_routes_node_ids.append(node_id)
|
all_routes_node_ids.add(node_id)
|
||||||
|
|
||||||
|
if branch_node_id in reverse_edge_mapping:
|
||||||
|
for graph_edge in reverse_edge_mapping[branch_node_id]:
|
||||||
|
if graph_edge.source_node_id not in parallel_start_node_ids:
|
||||||
|
parallel_start_node_ids[graph_edge.source_node_id] = []
|
||||||
|
|
||||||
|
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
|
||||||
|
|
||||||
|
parallel_start_node_id = None
|
||||||
|
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
|
||||||
|
if set(branch_node_ids) == set(routes_node_ids.keys()):
|
||||||
|
parallel_start_node_id = p_start_node_id
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not parallel_start_node_id:
|
||||||
|
raise Exception("Parallel start node id not found")
|
||||||
|
|
||||||
for graph_edge in reverse_edge_mapping[start_node_id]:
|
for graph_edge in reverse_edge_mapping[start_node_id]:
|
||||||
if graph_edge.source_node_id not in all_routes_node_ids:
|
if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -352,7 +352,13 @@ class GraphEngine:
|
|||||||
# if nodes has no run conditions, parallel run all nodes
|
# if nodes has no run conditions, parallel run all nodes
|
||||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
||||||
if not parallel_id:
|
if not parallel_id:
|
||||||
raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
|
node_id = edge_mappings[0].target_node_id
|
||||||
|
node_config = self.graph.node_id_config_mapping.get(node_id)
|
||||||
|
if not node_config:
|
||||||
|
raise GraphRunFailedError(f'Node {node_id} related parallel not found.')
|
||||||
|
|
||||||
|
node_title = node_config.get('data', {}).get('title')
|
||||||
|
raise GraphRunFailedError(f'Node {node_title} related parallel not found.')
|
||||||
|
|
||||||
parallel = self.graph.parallel_mapping.get(parallel_id)
|
parallel = self.graph.parallel_mapping.get(parallel_id)
|
||||||
if not parallel:
|
if not parallel:
|
||||||
|
@ -700,6 +700,11 @@ def test_parallels_graph6():
|
|||||||
"source": "code3",
|
"source": "code3",
|
||||||
"target": "answer",
|
"target": "answer",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"id": "llm3-source-answer-target",
|
||||||
|
"source": "llm3",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"nodes": [
|
"nodes": [
|
||||||
{"data": {"type": "start"}, "id": "start"},
|
{"data": {"type": "start"}, "id": "start"},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user