fix(workflow): parallel not yield

This commit is contained in:
takatost 2024-08-28 16:13:38 +08:00
parent 8ba5673606
commit c2bb11405f

View File

@ -304,13 +304,17 @@ class GraphEngine:
if len(sub_edge_mappings) == 1: if len(sub_edge_mappings) == 1:
final_node_id = edge.target_node_id final_node_id = edge.target_node_id
else: else:
final_node_id, parallel_generator = self._run_parallel_branches( parallel_generator = self._run_parallel_branches(
edge_mappings=sub_edge_mappings, edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id, in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
) )
yield from parallel_generator for item in parallel_generator:
if isinstance(item, str):
final_node_id = item
else:
yield item
break break
@ -319,13 +323,17 @@ class GraphEngine:
next_node_id = final_node_id next_node_id = final_node_id
else: else:
next_node_id, parallel_generator = self._run_parallel_branches( parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings, edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id, in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
) )
yield from parallel_generator for item in parallel_generator:
if isinstance(item, str):
next_node_id = item
else:
yield item
if not next_node_id: if not next_node_id:
break break
@ -338,7 +346,7 @@ class GraphEngine:
edge_mappings: list[GraphEdge], edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None, in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None,
) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]: ) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes # if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id: if not parallel_id:
@ -368,7 +376,6 @@ class GraphEngine:
threads.append(thread) threads.append(thread)
thread.start() thread.start()
def parallel_generator() -> Generator[GraphEngineEvent, None, None]:
succeeded_count = 0 succeeded_count = 0
while True: while True:
try: try:
@ -389,20 +396,14 @@ class GraphEngine:
except queue.Empty: except queue.Empty:
continue continue
generator = parallel_generator()
# Join all threads # Join all threads
for thread in threads: for thread in threads:
thread.join() thread.join()
# get final node id # get final node id
final_node_id = parallel.end_to_node_id final_node_id = parallel.end_to_node_id
if not final_node_id: if final_node_id:
return None, generator yield final_node_id
next_node_id = final_node_id
return final_node_id, generator
def _run_parallel_node( def _run_parallel_node(
self, self,