diff --git a/api/core/workflow/entities/workflow_runtime_state_entities.py b/api/core/workflow/entities/workflow_runtime_state_entities.py index 6b55d843e3..ffe4e8f109 100644 --- a/api/core/workflow/entities/workflow_runtime_state_entities.py +++ b/api/core/workflow/entities/workflow_runtime_state_entities.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone from enum import Enum from typing import Optional @@ -8,11 +9,12 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph import Graph, GraphNode from core.workflow.nodes.base_node import BaseNode, UserFrom -from models.workflow import WorkflowType +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType class RuntimeNode(BaseModel): class Status(Enum): + PENDING = "pending" RUNNING = "running" SUCCESS = "success" FAILED = "failed" @@ -30,16 +32,16 @@ class RuntimeNode(BaseModel): node_run_result: Optional[NodeRunResult] = None """node run result""" - status: Status = Status.RUNNING + status: Status = Status.PENDING """node status""" - start_at: float + start_at: Optional[datetime] = None """start time""" - paused_at: Optional[float] = None + paused_at: Optional[datetime] = None """paused time""" - finished_at: Optional[float] = None + finished_at: Optional[datetime] = None """finished time""" failed_reason: Optional[str] = None @@ -48,6 +50,39 @@ class RuntimeNode(BaseModel): paused_by: Optional[str] = None """paused by""" + predecessor_runtime_node_id: Optional[str] = None + """predecessor runtime node id""" + + +class RuntimeGraph(BaseModel): + runtime_nodes: dict[str, RuntimeNode] = {} + """runtime nodes""" + + def add_runtime_node(self, runtime_node: RuntimeNode) -> None: + self.runtime_nodes[runtime_node.id] = runtime_node + + def add_link(self, source_runtime_node_id: str, target_runtime_node_id: str) -> None: + if source_runtime_node_id in self.runtime_nodes and target_runtime_node_id in self.runtime_nodes: + target_runtime_node = self.runtime_nodes[target_runtime_node_id] + target_runtime_node.predecessor_runtime_node_id = source_runtime_node_id + + def runtime_node_finished(self, runtime_node_id: str, node_run_result: NodeRunResult) -> None: + if runtime_node_id in self.runtime_nodes: + runtime_node = self.runtime_nodes[runtime_node_id] + runtime_node.status = RuntimeNode.Status.SUCCESS \ + if node_run_result.status == WorkflowNodeExecutionStatus.RUNNING \ + else RuntimeNode.Status.FAILED + runtime_node.node_run_result = node_run_result + runtime_node.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + runtime_node.failed_reason = node_run_result.error + + def runtime_node_paused(self, runtime_node_id: str, paused_by: Optional[str] = None) -> None: + if runtime_node_id in self.runtime_nodes: + runtime_node = self.runtime_nodes[runtime_node_id] + runtime_node.status = RuntimeNode.Status.PAUSED + runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) + runtime_node.paused_by = paused_by + class WorkflowRuntimeState(BaseModel): tenant_id: str @@ -64,3 +99,5 @@ class WorkflowRuntimeState(BaseModel): total_tokens: int = 0 node_run_steps: int = 0 + + runtime_graph: RuntimeGraph diff --git a/api/core/workflow/graph.py b/api/core/workflow/graph.py index 06e739b31a..25482e10c7 100644 --- a/api/core/workflow/graph.py +++ b/api/core/workflow/graph.py @@ -17,7 +17,7 @@ class GraphNode(BaseModel): source_handle: Optional[str] = None """current node source handle from the previous node result""" - is_continue_callback: Optional[Callable] = None + run_condition_callback: Optional[Callable] = None """condition function check if the node can be executed""" node_config: dict @@ -34,8 +34,8 @@ class GraphNode(BaseModel): class Graph(BaseModel): - graph: dict - """graph from workflow""" + graph_config: dict + """graph config from workflow""" graph_nodes: dict[str, GraphNode] = {} """graph nodes""" @@ -46,14 +46,14 @@ class Graph(BaseModel): def add_edge(self, edge_config: dict, source_node_config: dict, target_node_config: dict, - is_continue_callback: Optional[Callable] = None) -> None: + run_condition_callback: Optional[Callable] = None) -> None: """ Add edge to the graph :param edge_config: edge config :param source_node_config: source node config :param target_node_config: target node config - :param is_continue_callback: condition callback + :param run_condition_callback: condition callback """ source_node_id = source_node_config.get('id') if not source_node_id: @@ -77,17 +77,12 @@ class Graph(BaseModel): source_node.add_child(target_node_id) source_node.target_edge_config = edge_config - source_handle = None - if edge_config.get('sourceHandle'): - source_handle = edge_config.get('sourceHandle') - if target_node_id not in self.graph_nodes: target_graph_node = GraphNode( id=target_node_id, predecessor_node_id=source_node_id, node_config=target_node_config, - source_handle=source_handle, - is_continue_callback=is_continue_callback, + run_condition_callback=run_condition_callback, source_edge_config=edge_config, ) @@ -95,8 +90,7 @@ class Graph(BaseModel): else: target_node = self.graph_nodes[target_node_id] target_node.predecessor_node_id = source_node_id - target_node.source_handle = source_handle - target_node.is_continue_callback = is_continue_callback + target_node.run_condition_callback = run_condition_callback target_node.source_edge_config = edge_config def add_graph_node(self, graph_node: GraphNode) -> None: @@ -135,7 +129,7 @@ class Graph(BaseModel): if not graph_node.children_node_ids: return None - descendants_graph = Graph() + descendants_graph = Graph(graph_config=self.graph_config) descendants_graph.add_graph_node(graph_node) for child_node_id in graph_node.children_node_ids: