fix(workflow): parallel not yield

This commit is contained in:
takatost 2024-08-28 16:13:38 +08:00
parent 8ba5673606
commit c2bb11405f

View File

@ -304,13 +304,17 @@ class GraphEngine:
if len(sub_edge_mappings) == 1:
final_node_id = edge.target_node_id
else:
final_node_id, parallel_generator = self._run_parallel_branches(
parallel_generator = self._run_parallel_branches(
edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
yield from parallel_generator
for item in parallel_generator:
if isinstance(item, str):
final_node_id = item
else:
yield item
break
@ -319,13 +323,17 @@ class GraphEngine:
next_node_id = final_node_id
else:
next_node_id, parallel_generator = self._run_parallel_branches(
parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
yield from parallel_generator
for item in parallel_generator:
if isinstance(item, str):
next_node_id = item
else:
yield item
if not next_node_id:
break
@ -338,7 +346,7 @@ class GraphEngine:
edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]:
) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
@ -368,7 +376,6 @@ class GraphEngine:
threads.append(thread)
thread.start()
def parallel_generator() -> Generator[GraphEngineEvent, None, None]:
succeeded_count = 0
while True:
try:
@ -389,20 +396,14 @@ class GraphEngine:
except queue.Empty:
continue
generator = parallel_generator()
# Join all threads
for thread in threads:
thread.join()
# get final node id
final_node_id = parallel.end_to_node_id
if not final_node_id:
return None, generator
next_node_id = final_node_id
return final_node_id, generator
if final_node_id:
yield final_node_id
def _run_parallel_node(
self,