mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 04:25:54 +08:00
fix(workflow): parallel not yield
This commit is contained in:
parent
8ba5673606
commit
c2bb11405f
@ -304,13 +304,17 @@ class GraphEngine:
|
|||||||
if len(sub_edge_mappings) == 1:
|
if len(sub_edge_mappings) == 1:
|
||||||
final_node_id = edge.target_node_id
|
final_node_id = edge.target_node_id
|
||||||
else:
|
else:
|
||||||
final_node_id, parallel_generator = self._run_parallel_branches(
|
parallel_generator = self._run_parallel_branches(
|
||||||
edge_mappings=sub_edge_mappings,
|
edge_mappings=sub_edge_mappings,
|
||||||
in_parallel_id=in_parallel_id,
|
in_parallel_id=in_parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id
|
parallel_start_node_id=parallel_start_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
yield from parallel_generator
|
for item in parallel_generator:
|
||||||
|
if isinstance(item, str):
|
||||||
|
final_node_id = item
|
||||||
|
else:
|
||||||
|
yield item
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -319,13 +323,17 @@ class GraphEngine:
|
|||||||
|
|
||||||
next_node_id = final_node_id
|
next_node_id = final_node_id
|
||||||
else:
|
else:
|
||||||
next_node_id, parallel_generator = self._run_parallel_branches(
|
parallel_generator = self._run_parallel_branches(
|
||||||
edge_mappings=edge_mappings,
|
edge_mappings=edge_mappings,
|
||||||
in_parallel_id=in_parallel_id,
|
in_parallel_id=in_parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id
|
parallel_start_node_id=parallel_start_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
yield from parallel_generator
|
for item in parallel_generator:
|
||||||
|
if isinstance(item, str):
|
||||||
|
next_node_id = item
|
||||||
|
else:
|
||||||
|
yield item
|
||||||
|
|
||||||
if not next_node_id:
|
if not next_node_id:
|
||||||
break
|
break
|
||||||
@ -338,7 +346,7 @@ class GraphEngine:
|
|||||||
edge_mappings: list[GraphEdge],
|
edge_mappings: list[GraphEdge],
|
||||||
in_parallel_id: Optional[str] = None,
|
in_parallel_id: Optional[str] = None,
|
||||||
parallel_start_node_id: Optional[str] = None,
|
parallel_start_node_id: Optional[str] = None,
|
||||||
) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]:
|
) -> Generator[GraphEngineEvent | str, None, None]:
|
||||||
# if nodes has no run conditions, parallel run all nodes
|
# if nodes has no run conditions, parallel run all nodes
|
||||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
||||||
if not parallel_id:
|
if not parallel_id:
|
||||||
@ -368,41 +376,34 @@ class GraphEngine:
|
|||||||
threads.append(thread)
|
threads.append(thread)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
def parallel_generator() -> Generator[GraphEngineEvent, None, None]:
|
succeeded_count = 0
|
||||||
succeeded_count = 0
|
while True:
|
||||||
while True:
|
try:
|
||||||
try:
|
event = q.get(timeout=1)
|
||||||
event = q.get(timeout=1)
|
if event is None:
|
||||||
if event is None:
|
break
|
||||||
break
|
|
||||||
|
|
||||||
yield event
|
yield event
|
||||||
if event.parallel_id == parallel_id:
|
if event.parallel_id == parallel_id:
|
||||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||||
succeeded_count += 1
|
succeeded_count += 1
|
||||||
if succeeded_count == len(threads):
|
if succeeded_count == len(threads):
|
||||||
q.put(None)
|
q.put(None)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||||
raise GraphRunFailedError(event.error)
|
raise GraphRunFailedError(event.error)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
generator = parallel_generator()
|
|
||||||
|
|
||||||
# Join all threads
|
# Join all threads
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
# get final node id
|
# get final node id
|
||||||
final_node_id = parallel.end_to_node_id
|
final_node_id = parallel.end_to_node_id
|
||||||
if not final_node_id:
|
if final_node_id:
|
||||||
return None, generator
|
yield final_node_id
|
||||||
|
|
||||||
next_node_id = final_node_id
|
|
||||||
|
|
||||||
return final_node_id, generator
|
|
||||||
|
|
||||||
def _run_parallel_node(
|
def _run_parallel_node(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user