diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 14676b66cc..8e5bb06c40 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -304,13 +304,17 @@ class GraphEngine: if len(sub_edge_mappings) == 1: final_node_id = edge.target_node_id else: - final_node_id, parallel_generator = self._run_parallel_branches( + parallel_generator = self._run_parallel_branches( edge_mappings=sub_edge_mappings, in_parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id ) - yield from parallel_generator + for item in parallel_generator: + if isinstance(item, str): + final_node_id = item + else: + yield item break @@ -319,13 +323,17 @@ class GraphEngine: next_node_id = final_node_id else: - next_node_id, parallel_generator = self._run_parallel_branches( + parallel_generator = self._run_parallel_branches( edge_mappings=edge_mappings, in_parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id ) - yield from parallel_generator + for item in parallel_generator: + if isinstance(item, str): + next_node_id = item + else: + yield item if not next_node_id: break @@ -338,7 +346,7 @@ class GraphEngine: edge_mappings: list[GraphEdge], in_parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, - ) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]: + ) -> Generator[GraphEngineEvent | str, None, None]: # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) if not parallel_id: @@ -368,41 +376,34 @@ class GraphEngine: threads.append(thread) thread.start() - def parallel_generator() -> Generator[GraphEngineEvent, None, None]: - succeeded_count = 0 - while True: - try: - event = q.get(timeout=1) - if event is None: - break + succeeded_count = 0 + while True: + try: + event = q.get(timeout=1) + if event is None: + break - yield event - if event.parallel_id == parallel_id: - if isinstance(event, ParallelBranchRunSucceededEvent): - succeeded_count += 1 - if succeeded_count == len(threads): - q.put(None) + yield event + if event.parallel_id == parallel_id: + if isinstance(event, ParallelBranchRunSucceededEvent): + succeeded_count += 1 + if succeeded_count == len(threads): + q.put(None) - continue - elif isinstance(event, ParallelBranchRunFailedEvent): - raise GraphRunFailedError(event.error) - except queue.Empty: - continue + continue + elif isinstance(event, ParallelBranchRunFailedEvent): + raise GraphRunFailedError(event.error) + except queue.Empty: + continue - generator = parallel_generator() - # Join all threads for thread in threads: thread.join() # get final node id final_node_id = parallel.end_to_node_id - if not final_node_id: - return None, generator - - next_node_id = final_node_id - - return final_node_id, generator + if final_node_id: + yield final_node_id def _run_parallel_node( self,