fix(workflow): fix merge branch node id err

This commit is contained in:
takatost 2024-09-02 13:56:07 +08:00
parent 0dabf799c0
commit 52b4623131

View File

@ -508,11 +508,17 @@ class Graph(BaseModel):
branch_node_id != branch_node_id2
and node_id in inner_route2
and len(reverse_edge_mapping.get(node_id, [])) > 1
and cls._is_node_in_routes(
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=node_id,
routes_node_ids=routes_node_ids
)
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []
merge_branch_node_ids[node_id].append(branch_node_id2)
if branch_node_id2 not in merge_branch_node_ids[node_id]:
merge_branch_node_ids[node_id].append(branch_node_id2)
# sorted merge_branch_node_ids by branch_node_ids length desc
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
@ -596,6 +602,28 @@ class Graph(BaseModel):
routes_node_ids=routes_node_ids
)
@classmethod
def _is_node_in_routes(cls,
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
routes_node_ids: dict[str, list[str]]) -> bool:
"""
Recursively check if the node is in the routes
"""
if start_node_id not in reverse_edge_mapping:
return False
all_routes_node_ids = []
for _, node_ids in routes_node_ids.items():
for node_id in node_ids:
all_routes_node_ids.append(node_id)
for graph_edge in reverse_edge_mapping[start_node_id]:
if graph_edge.source_node_id not in all_routes_node_ids:
return False
return True
@classmethod
def _is_node2_after_node1(
cls,