diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py
index 80f6b6a8a7..4099def4e2 100644
--- a/api/core/workflow/graph_engine/condition_handlers/base_handler.py
+++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py
@@ -19,14 +19,13 @@ class RunConditionHandler(ABC):
     @abstractmethod
     def check(self,
               graph_runtime_state: GraphRuntimeState,
-              previous_route_node_state: RouteNodeState,
-              target_node_id: str) -> bool:
+              previous_route_node_state: RouteNodeState
+              ) -> bool:
         """
         Check if the condition can be executed

         :param graph_runtime_state: graph runtime state
         :param previous_route_node_state: previous route node state
-        :param target_node_id: target node id
         :return: bool
         """
         raise NotImplementedError
diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py
index 17212440f7..705eb908b1 100644
--- a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py
+++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py
@@ -7,14 +7,12 @@ class BranchIdentifyRunConditionHandler(RunConditionHandler):

     def check(self,
               graph_runtime_state: GraphRuntimeState,
-              previous_route_node_state: RouteNodeState,
-              target_node_id: str) -> bool:
+              previous_route_node_state: RouteNodeState) -> bool:
         """
         Check if the condition can be executed

         :param graph_runtime_state: graph runtime state
         :param previous_route_node_state: previous route node state
-        :param target_node_id: target node id
         :return: bool
         """
         if not self.condition.branch_identify:
diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py
index fce5ef13f1..1edaf92da7 100644
--- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py
+++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py
@@ -7,14 +7,13 @@ from core.workflow.utils.condition.processor import ConditionProcessor
 class ConditionRunConditionHandlerHandler(RunConditionHandler):
     def check(self,
               graph_runtime_state: GraphRuntimeState,
-              previous_route_node_state: RouteNodeState,
-              target_node_id: str) -> bool:
+              previous_route_node_state: RouteNodeState
+              ) -> bool:
         """
         Check if the condition can be executed

         :param graph_runtime_state: graph runtime state
         :param previous_route_node_state: previous route node state
-        :param target_node_id: target node id
         :return: bool
         """
         if not self.condition.conditions:
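The three handler diffs above drop the unused target_node_id parameter, so check() now decides purely from the graph runtime state and the previous route node state. As a minimal sketch of the new signature (not part of this diff; the class name AlwaysRunConditionHandler is made up, and the imports are assumed to resolve as the file paths above suggest):

```python
# Hypothetical sketch, not part of the diff: a handler written against the
# narrowed two-argument check() signature introduced above.
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class AlwaysRunConditionHandler(RunConditionHandler):
    def check(self,
              graph_runtime_state: GraphRuntimeState,
              previous_route_node_state: RouteNodeState
              ) -> bool:
        # No target node id is needed any more; the decision depends only on
        # the previous node's state and the shared runtime state.
        return previous_route_node_state is not None
```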
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 853eee7126..32f2859659 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -314,8 +314,22 @@ class Graph(BaseModel):
         target_node_edges = edge_mapping.get(start_node_id, [])
         if len(target_node_edges) > 1:
             # fetch all node ids in current parallels
-            parallel_node_ids = [graph_edge.target_node_id
-                                 for graph_edge in target_node_edges if graph_edge.run_condition is None]
+            parallel_node_ids = []
+            condition_edge_mappings = {}
+            for graph_edge in target_node_edges:
+                if graph_edge.run_condition is None:
+                    parallel_node_ids.append(graph_edge.target_node_id)
+                else:
+                    condition_hash = graph_edge.run_condition.hash
+                    if condition_hash not in condition_edge_mappings:
+                        condition_edge_mappings[condition_hash] = []
+
+                    condition_edge_mappings[condition_hash].append(graph_edge)
+
+            for _, graph_edges in condition_edge_mappings.items():
+                if len(graph_edges) > 1:
+                    for graph_edge in graph_edges:
+                        parallel_node_ids.append(graph_edge.target_node_id)

             # any target node id in node_parallel_mapping
             if parallel_node_ids:
diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py
index 46fca1b032..0362343568 100644
--- a/api/core/workflow/graph_engine/entities/run_condition.py
+++ b/api/core/workflow/graph_engine/entities/run_condition.py
@@ -1,3 +1,4 @@
+import hashlib
 from typing import Literal, Optional

 from pydantic import BaseModel
@@ -14,3 +15,7 @@ class RunCondition(BaseModel):

     conditions: Optional[list[Condition]] = None
     """conditions to run the node, required when type is condition"""
+
+    @property
+    def hash(self) -> str:
+        return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
\ No newline at end of file
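The new RunCondition.hash property is what lets graph.py and graph_engine.py treat edges with identical run conditions as one group: identical pydantic payloads serialize to the same JSON and therefore the same SHA-256 digest. A small illustration (not part of the diff; it assumes the type literal accepts 'branch_identify' and that branch_identify is a plain string field, as the handlers above suggest):

```python
# Hypothetical sketch, not part of the diff: grouping conditions by the new
# hash property, mirroring what graph.py does with condition_edge_mappings.
from collections import defaultdict

from core.workflow.graph_engine.entities.run_condition import RunCondition

# Assumed field values, for illustration only.
take_true = RunCondition(type='branch_identify', branch_identify='true')
take_true_again = RunCondition(type='branch_identify', branch_identify='true')
take_false = RunCondition(type='branch_identify', branch_identify='false')

# Identical payloads produce identical JSON, hence identical digests.
assert take_true.hash == take_true_again.hash
assert take_true.hash != take_false.hash

# Conditions (and therefore edges) sharing a hash land in the same bucket;
# the engine evaluates each bucket's condition once and fans out its targets.
buckets: dict[str, list[RunCondition]] = defaultdict(list)
for condition in (take_true, take_true_again, take_false):
    buckets[condition.hash].append(condition)

assert len(buckets) == 2
```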
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 29689a8feb..14676b66cc 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -32,7 +32,7 @@ from core.workflow.graph_engine.entities.event import (
     ParallelBranchRunStartedEvent,
     ParallelBranchRunSucceededEvent,
 )
-from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
 from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
@@ -262,8 +262,7 @@ class GraphEngine:
                         run_condition=edge.run_condition,
                     ).check(
                         graph_runtime_state=self.graph_runtime_state,
-                        previous_route_node_state=previous_route_node_state,
-                        target_node_id=edge.target_node_id,
+                        previous_route_node_state=previous_route_node_state
                     )

                     if not result:
@@ -274,90 +273,137 @@ class GraphEngine:
                 if any(edge.run_condition for edge in edge_mappings):
                     # if nodes has run conditions, get node id which branch to take based on the run condition results
                     final_node_id = None
+
+                    condition_edge_mappings = {}
                     for edge in edge_mappings:
                         if edge.run_condition:
-                            result = ConditionManager.get_condition_handler(
-                                init_params=self.init_params,
-                                graph=self.graph,
-                                run_condition=edge.run_condition,
-                            ).check(
-                                graph_runtime_state=self.graph_runtime_state,
-                                previous_route_node_state=previous_route_node_state,
-                                target_node_id=edge.target_node_id,
+                            run_condition_hash = edge.run_condition.hash
+                            if run_condition_hash not in condition_edge_mappings:
+                                condition_edge_mappings[run_condition_hash] = []
+
+                            condition_edge_mappings[run_condition_hash].append(edge)
+
+                    for _, sub_edge_mappings in condition_edge_mappings.items():
+                        if len(sub_edge_mappings) == 0:
+                            continue
+
+                        edge = sub_edge_mappings[0]
+
+                        result = ConditionManager.get_condition_handler(
+                            init_params=self.init_params,
+                            graph=self.graph,
+                            run_condition=edge.run_condition,
+                        ).check(
+                            graph_runtime_state=self.graph_runtime_state,
+                            previous_route_node_state=previous_route_node_state,
+                        )
+
+                        if not result:
+                            continue
+
+                        if len(sub_edge_mappings) == 1:
+                            final_node_id = edge.target_node_id
+                        else:
+                            final_node_id, parallel_generator = self._run_parallel_branches(
+                                edge_mappings=sub_edge_mappings,
+                                in_parallel_id=in_parallel_id,
+                                parallel_start_node_id=parallel_start_node_id
                             )

-                            if result:
-                                final_node_id = edge.target_node_id
-                                break
+                            yield from parallel_generator
+
+                        break

                     if not final_node_id:
                         break

                     next_node_id = final_node_id
                 else:
-                    # if nodes has no run conditions, parallel run all nodes
-                    parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
-                    if not parallel_id:
-                        raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
+                    next_node_id, parallel_generator = self._run_parallel_branches(
+                        edge_mappings=edge_mappings,
+                        in_parallel_id=in_parallel_id,
+                        parallel_start_node_id=parallel_start_node_id
+                    )

-                    parallel = self.graph.parallel_mapping.get(parallel_id)
-                    if not parallel:
-                        raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
+                    yield from parallel_generator

-                    # run parallel nodes, run in new thread and use queue to get results
-                    q: queue.Queue = queue.Queue()
-
-                    # Create a list to store the threads
-                    threads = []
-
-                    # new thread
-                    for edge in edge_mappings:
-                        thread = threading.Thread(target=self._run_parallel_node, kwargs={
-                            'flask_app': current_app._get_current_object(),  # type: ignore[attr-defined]
-                            'q': q,
-                            'parallel_id': parallel_id,
-                            'parallel_start_node_id': edge.target_node_id,
-                            'parent_parallel_id': in_parallel_id,
-                            'parent_parallel_start_node_id': parallel_start_node_id,
-                        })
-
-                        threads.append(thread)
-                        thread.start()
-
-                    succeeded_count = 0
-                    while True:
-                        try:
-                            event = q.get(timeout=1)
-                            if event is None:
-                                break
-
-                            yield event
-                            if event.parallel_id == parallel_id:
-                                if isinstance(event, ParallelBranchRunSucceededEvent):
-                                    succeeded_count += 1
-                                    if succeeded_count == len(threads):
-                                        q.put(None)
-
-                                    continue
-                                elif isinstance(event, ParallelBranchRunFailedEvent):
-                                    raise GraphRunFailedError(event.error)
-                        except queue.Empty:
-                            continue
-
-                    # Join all threads
-                    for thread in threads:
-                        thread.join()
-
-                    # get final node id
-                    final_node_id = parallel.end_to_node_id
-                    if not final_node_id:
+                    if not next_node_id:
                         break

-                next_node_id = final_node_id
-
             if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
                 break

+    def _run_parallel_branches(
+        self,
+        edge_mappings: list[GraphEdge],
+        in_parallel_id: Optional[str] = None,
+        parallel_start_node_id: Optional[str] = None,
+    ) -> tuple[Optional[str], Generator[GraphEngineEvent, None, None]]:
+        # if nodes has no run conditions, parallel run all nodes
+        parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
+        if not parallel_id:
+            raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
+
+        parallel = self.graph.parallel_mapping.get(parallel_id)
+        if not parallel:
+            raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
+
+        # run parallel nodes, run in new thread and use queue to get results
+        q: queue.Queue = queue.Queue()
+
+        # Create a list to store the threads
+        threads = []
+
+        # new thread
+        for edge in edge_mappings:
+            thread = threading.Thread(target=self._run_parallel_node, kwargs={
+                'flask_app': current_app._get_current_object(),  # type: ignore[attr-defined]
+                'q': q,
+                'parallel_id': parallel_id,
+                'parallel_start_node_id': edge.target_node_id,
+                'parent_parallel_id': in_parallel_id,
+                'parent_parallel_start_node_id': parallel_start_node_id,
+            })
+
+            threads.append(thread)
+            thread.start()
+
+        def parallel_generator() -> Generator[GraphEngineEvent, None, None]:
+            succeeded_count = 0
+            while True:
+                try:
+                    event = q.get(timeout=1)
+                    if event is None:
+                        break
+
+                    yield event
+                    if event.parallel_id == parallel_id:
+                        if isinstance(event, ParallelBranchRunSucceededEvent):
+                            succeeded_count += 1
+                            if succeeded_count == len(threads):
+                                q.put(None)
+
+                            continue
+                        elif isinstance(event, ParallelBranchRunFailedEvent):
+                            raise GraphRunFailedError(event.error)
+                except queue.Empty:
+                    continue
+
+        generator = parallel_generator()
+
+        # Join all threads
+        for thread in threads:
+            thread.join()
+
+        # get final node id
+        final_node_id = parallel.end_to_node_id
+        if not final_node_id:
+            return None, generator
+
+        next_node_id = final_node_id
+
+        return final_node_id, generator
+
     def _run_parallel_node(
         self,
         flask_app: Flask,
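The extracted _run_parallel_branches keeps the original fan-out mechanics: one thread per edge pushes events onto a shared queue, and a nested generator drains the queue until a None sentinel signals that every branch has reported a ParallelBranchRunSucceededEvent. Note that the generator is only created, not consumed, before the threads are joined, so the queue buffers branch events until the caller iterates it. A stripped-down, self-contained sketch of the same pattern (not the engine's code; run_branches and the string events are invented for illustration):

```python
# Hypothetical sketch, not part of the diff: the thread + queue + sentinel
# fan-out/fan-in pattern that _run_parallel_branches is built on.
import queue
import threading
from collections.abc import Generator


def run_branches(branch_ids: list[str]) -> Generator[str, None, None]:
    q: queue.Queue = queue.Queue()

    def worker(branch_id: str) -> None:
        # A real branch would emit many events; this one reports success once.
        q.put(f'{branch_id}: succeeded')

    threads = [threading.Thread(target=worker, args=(branch_id,)) for branch_id in branch_ids]
    for thread in threads:
        thread.start()

    succeeded_count = 0
    while True:
        try:
            event = q.get(timeout=1)
        except queue.Empty:
            continue

        if event is None:
            break

        yield event
        succeeded_count += 1
        if succeeded_count == len(threads):
            q.put(None)  # sentinel: every branch has reported back

    for thread in threads:
        thread.join()


for event in run_branches(['branch-1', 'branch-2', 'branch-3']):
    print(event)
```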
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index 17de1142fb..44b2cd4086 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -99,9 +99,7 @@ class AppGenerateService:
         return max_active_requests

     @classmethod
-    def generate_single_iteration(
-        cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True
-    ):
+    def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
         if app_model.mode == AppMode.ADVANCED_CHAT.value:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return AdvancedChatAppGenerator().single_iteration_generate(
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 9c36193328..357ffd41c1 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -234,7 +234,7 @@ class WorkflowService:
                 break

         if not node_run_result:
-            raise ValueError('Node run failed with no run result')
+            raise ValueError("Node run failed with no run result")

         run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
         error = node_run_result.error if not run_succeeded else None
@@ -262,9 +262,15 @@ class WorkflowService:
         if run_succeeded and node_run_result:
             # create workflow node execution
             workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
-            workflow_node_execution.process_data = json.dumps(node_run_result.process_data) if node_run_result.process_data else None
-            workflow_node_execution.outputs = json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
-            workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
+            workflow_node_execution.process_data = (
+                json.dumps(node_run_result.process_data) if node_run_result.process_data else None
+            )
+            workflow_node_execution.outputs = (
+                json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
+            )
+            workflow_node_execution.execution_metadata = (
+                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
+            )
             workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
         else:
             # create workflow node execution
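The workflow_service.py hunk only re-wraps three long conditional assignments in parentheses; behaviour is unchanged. Since all of them follow the same "dump to JSON if present, else None" shape, a tiny helper could express the pattern once. This is a hypothetical refactor, not something the diff introduces; dumps_or_none and its encoder parameter are invented here:

```python
# Hypothetical sketch, not part of the diff: one helper for the repeated
# "serialize if present, else None" assignments in workflow_service.py.
import json
from collections.abc import Callable
from typing import Any, Optional


def dumps_or_none(value: Any, encoder: Optional[Callable[[Any], Any]] = None) -> Optional[str]:
    """Return the JSON dump of value (optionally pre-encoded), or None if it is empty."""
    if not value:
        return None
    return json.dumps(encoder(value) if encoder else value)


# Usage mirroring the assignments above:
#   workflow_node_execution.process_data = dumps_or_none(node_run_result.process_data)
#   workflow_node_execution.outputs = dumps_or_none(node_run_result.outputs, jsonable_encoder)
#   workflow_node_execution.execution_metadata = dumps_or_none(node_run_result.metadata, jsonable_encoder)
```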