mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 00:05:54 +08:00
fix(workflow): fix merge branch node id err
This commit is contained in:
parent
29b1ce781d
commit
0dabf799c0
@ -174,6 +174,7 @@ class Graph(BaseModel):
|
|||||||
node_parallel_mapping: dict[str, str] = {}
|
node_parallel_mapping: dict[str, str] = {}
|
||||||
cls._recursively_add_parallels(
|
cls._recursively_add_parallels(
|
||||||
edge_mapping=edge_mapping,
|
edge_mapping=edge_mapping,
|
||||||
|
reverse_edge_mapping=reverse_edge_mapping,
|
||||||
start_node_id=root_node_id,
|
start_node_id=root_node_id,
|
||||||
parallel_mapping=parallel_mapping,
|
parallel_mapping=parallel_mapping,
|
||||||
node_parallel_mapping=node_parallel_mapping
|
node_parallel_mapping=node_parallel_mapping
|
||||||
@ -310,6 +311,7 @@ class Graph(BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _recursively_add_parallels(cls,
|
def _recursively_add_parallels(cls,
|
||||||
edge_mapping: dict[str, list[GraphEdge]],
|
edge_mapping: dict[str, list[GraphEdge]],
|
||||||
|
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||||
start_node_id: str,
|
start_node_id: str,
|
||||||
parallel_mapping: dict[str, GraphParallel],
|
parallel_mapping: dict[str, GraphParallel],
|
||||||
node_parallel_mapping: dict[str, str]) -> None:
|
node_parallel_mapping: dict[str, str]) -> None:
|
||||||
@ -365,6 +367,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
||||||
edge_mapping=edge_mapping,
|
edge_mapping=edge_mapping,
|
||||||
|
reverse_edge_mapping=reverse_edge_mapping,
|
||||||
parallel_branch_node_ids=parallel_branch_node_ids
|
parallel_branch_node_ids=parallel_branch_node_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -412,6 +415,7 @@ class Graph(BaseModel):
|
|||||||
for graph_edge in target_node_edges:
|
for graph_edge in target_node_edges:
|
||||||
cls._recursively_add_parallels(
|
cls._recursively_add_parallels(
|
||||||
edge_mapping=edge_mapping,
|
edge_mapping=edge_mapping,
|
||||||
|
reverse_edge_mapping=reverse_edge_mapping,
|
||||||
start_node_id=graph_edge.target_node_id,
|
start_node_id=graph_edge.target_node_id,
|
||||||
parallel_mapping=parallel_mapping,
|
parallel_mapping=parallel_mapping,
|
||||||
node_parallel_mapping=node_parallel_mapping
|
node_parallel_mapping=node_parallel_mapping
|
||||||
@ -472,6 +476,7 @@ class Graph(BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _fetch_all_node_ids_in_parallels(cls,
|
def _fetch_all_node_ids_in_parallels(cls,
|
||||||
edge_mapping: dict[str, list[GraphEdge]],
|
edge_mapping: dict[str, list[GraphEdge]],
|
||||||
|
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||||
parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
|
parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
Fetch all node ids in parallels
|
Fetch all node ids in parallels
|
||||||
@ -499,7 +504,11 @@ class Graph(BaseModel):
|
|||||||
leaf_node_ids[branch_node_id].append(node_id)
|
leaf_node_ids[branch_node_id].append(node_id)
|
||||||
|
|
||||||
for branch_node_id2, inner_route2 in routes_node_ids.items():
|
for branch_node_id2, inner_route2 in routes_node_ids.items():
|
||||||
if branch_node_id != branch_node_id2 and node_id in inner_route2:
|
if (
|
||||||
|
branch_node_id != branch_node_id2
|
||||||
|
and node_id in inner_route2
|
||||||
|
and len(reverse_edge_mapping.get(node_id, [])) > 1
|
||||||
|
):
|
||||||
if node_id not in merge_branch_node_ids:
|
if node_id not in merge_branch_node_ids:
|
||||||
merge_branch_node_ids[node_id] = []
|
merge_branch_node_ids[node_id] = []
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user