mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 17:05:54 +08:00
fix(workflow): duplicate nodes in parallel
This commit is contained in:
parent
5d34e080eb
commit
790dd3b22f
@ -335,8 +335,10 @@ class Graph(BaseModel):
|
||||
if parallel_node_ids:
|
||||
# all parallel_node_ids in node_parallel_mapping
|
||||
parent_parallel_id = None
|
||||
if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
|
||||
parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]
|
||||
for node_id in parallel_node_ids:
|
||||
if node_id in node_parallel_mapping:
|
||||
parent_parallel_id = node_parallel_mapping[node_id]
|
||||
break
|
||||
|
||||
parent_parallel = None
|
||||
if parent_parallel_id:
|
||||
@ -392,7 +394,10 @@ class Graph(BaseModel):
|
||||
outside_parallel_target_node_ids.add(target_node_id)
|
||||
|
||||
if len(outside_parallel_target_node_ids) == 1:
|
||||
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
||||
if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
|
||||
parallel.end_to_node_id = None
|
||||
else:
|
||||
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
||||
|
||||
for graph_edge in target_node_edges:
|
||||
cls._recursively_add_parallels(
|
||||
@ -436,7 +441,7 @@ class Graph(BaseModel):
|
||||
"""
|
||||
routes_node_ids: dict[str, list[str]] = {}
|
||||
for parallel_node_id in parallel_node_ids:
|
||||
routes_node_ids[parallel_node_id] = []
|
||||
routes_node_ids[parallel_node_id] = [parallel_node_id]
|
||||
|
||||
# fetch routes node ids
|
||||
cls._recursively_fetch_routes(
|
||||
@ -479,12 +484,16 @@ class Graph(BaseModel):
|
||||
|
||||
in_branch_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
in_branch_node_ids[branch_node_id] = [branch_node_id]
|
||||
in_branch_node_ids[branch_node_id] = []
|
||||
if branch_node_id not in branches_merge_node_ids:
|
||||
# all node ids in current branch is in this thread
|
||||
in_branch_node_ids[branch_node_id].append(branch_node_id)
|
||||
in_branch_node_ids[branch_node_id].extend(node_ids)
|
||||
else:
|
||||
merge_node_id = branches_merge_node_ids[branch_node_id]
|
||||
if merge_node_id != branch_node_id:
|
||||
in_branch_node_ids[branch_node_id].append(branch_node_id)
|
||||
|
||||
# fetch all node ids from branch_node_id and merge_node_id
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=in_branch_node_ids[branch_node_id],
|
||||
|
@ -366,6 +366,12 @@ class GraphEngine:
|
||||
|
||||
# new thread
|
||||
for edge in edge_mappings:
|
||||
if (
|
||||
edge.target_node_id not in self.graph.node_parallel_mapping
|
||||
or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id
|
||||
):
|
||||
continue
|
||||
|
||||
thread = threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
'q': q,
|
||||
|
Loading…
x
Reference in New Issue
Block a user