fix(workflow): parallel not yield

This commit is contained in:
takatost 2024-08-28 16:13:38 +08:00
parent 8ba5673606
commit c2bb11405f

View File

@ -304,13 +304,17 @@ class GraphEngine:
if len(sub_edge_mappings) == 1: if len(sub_edge_mappings) == 1:
final_node_id = edge.target_node_id final_node_id = edge.target_node_id
else: else:
final_node_id, parallel_generator = self._run_parallel_branches( parallel_generator = self._run_parallel_branches(
edge_mappings=sub_edge_mappings, edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id, in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
) )
yield from parallel_generator for item in parallel_generator:
if isinstance(item, str):
final_node_id = item
else:
yield item
break break
@ -319,13 +323,17 @@ class GraphEngine:
next_node_id = final_node_id next_node_id = final_node_id
else: else:
next_node_id, parallel_generator = self._run_parallel_branches( parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings, edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id, in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
) )
yield from parallel_generator for item in parallel_generator:
if isinstance(item, str):
next_node_id = item
else:
yield item
if not next_node_id: if not next_node_id:
break break
@ -338,7 +346,7 @@ class GraphEngine:
edge_mappings: list[GraphEdge], edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None, in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None,
) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]: ) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes # if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id: if not parallel_id:
@ -368,41 +376,34 @@ class GraphEngine:
threads.append(thread) threads.append(thread)
thread.start() thread.start()
def parallel_generator() -> Generator[GraphEngineEvent, None, None]: succeeded_count = 0
succeeded_count = 0 while True:
while True: try:
try: event = q.get(timeout=1)
event = q.get(timeout=1) if event is None:
if event is None: break
break
yield event yield event
if event.parallel_id == parallel_id: if event.parallel_id == parallel_id:
if isinstance(event, ParallelBranchRunSucceededEvent): if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1 succeeded_count += 1
if succeeded_count == len(threads): if succeeded_count == len(threads):
q.put(None) q.put(None)
continue continue
elif isinstance(event, ParallelBranchRunFailedEvent): elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.error) raise GraphRunFailedError(event.error)
except queue.Empty: except queue.Empty:
continue continue
generator = parallel_generator()
# Join all threads # Join all threads
for thread in threads: for thread in threads:
thread.join() thread.join()
# get final node id # get final node id
final_node_id = parallel.end_to_node_id final_node_id = parallel.end_to_node_id
if not final_node_id: if final_node_id:
return None, generator yield final_node_id
next_node_id = final_node_id
return final_node_id, generator
def _run_parallel_node( def _run_parallel_node(
self, self,