add runtime state of graph

This commit is contained in:
takatost 2024-06-25 17:43:13 +08:00
parent fe27c97fd9
commit 216910a4a1
2 changed files with 50 additions and 19 deletions

View File

@ -1,3 +1,4 @@
from datetime import datetime, timezone
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
@ -8,11 +9,12 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph import Graph, GraphNode from core.workflow.graph import Graph, GraphNode
from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.base_node import BaseNode, UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
class RuntimeNode(BaseModel): class RuntimeNode(BaseModel):
class Status(Enum): class Status(Enum):
PENDING = "pending"
RUNNING = "running" RUNNING = "running"
SUCCESS = "success" SUCCESS = "success"
FAILED = "failed" FAILED = "failed"
@ -30,16 +32,16 @@ class RuntimeNode(BaseModel):
node_run_result: Optional[NodeRunResult] = None node_run_result: Optional[NodeRunResult] = None
"""node run result""" """node run result"""
status: Status = Status.RUNNING status: Status = Status.PENDING
"""node status""" """node status"""
start_at: float start_at: Optional[datetime] = None
"""start time""" """start time"""
paused_at: Optional[float] = None paused_at: Optional[datetime] = None
"""paused time""" """paused time"""
finished_at: Optional[float] = None finished_at: Optional[datetime] = None
"""finished time""" """finished time"""
failed_reason: Optional[str] = None failed_reason: Optional[str] = None
@ -48,6 +50,39 @@ class RuntimeNode(BaseModel):
paused_by: Optional[str] = None paused_by: Optional[str] = None
"""paused by""" """paused by"""
predecessor_runtime_node_id: Optional[str] = None
"""predecessor runtime node id"""
class RuntimeGraph(BaseModel):
runtime_nodes: dict[str, RuntimeNode] = {}
"""runtime nodes"""
def add_runtime_node(self, runtime_node: RuntimeNode) -> None:
self.runtime_nodes[runtime_node.id] = runtime_node
def add_link(self, source_runtime_node_id: str, target_runtime_node_id: str) -> None:
if source_runtime_node_id in self.runtime_nodes and target_runtime_node_id in self.runtime_nodes:
target_runtime_node = self.runtime_nodes[target_runtime_node_id]
target_runtime_node.predecessor_runtime_node_id = source_runtime_node_id
def runtime_node_finished(self, runtime_node_id: str, node_run_result: NodeRunResult) -> None:
if runtime_node_id in self.runtime_nodes:
runtime_node = self.runtime_nodes[runtime_node_id]
runtime_node.status = RuntimeNode.Status.SUCCESS \
if node_run_result.status == WorkflowNodeExecutionStatus.RUNNING \
else RuntimeNode.Status.FAILED
runtime_node.node_run_result = node_run_result
runtime_node.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
runtime_node.failed_reason = node_run_result.error
def runtime_node_paused(self, runtime_node_id: str, paused_by: Optional[str] = None) -> None:
if runtime_node_id in self.runtime_nodes:
runtime_node = self.runtime_nodes[runtime_node_id]
runtime_node.status = RuntimeNode.Status.PAUSED
runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
runtime_node.paused_by = paused_by
class WorkflowRuntimeState(BaseModel): class WorkflowRuntimeState(BaseModel):
tenant_id: str tenant_id: str
@ -64,3 +99,5 @@ class WorkflowRuntimeState(BaseModel):
total_tokens: int = 0 total_tokens: int = 0
node_run_steps: int = 0 node_run_steps: int = 0
runtime_graph: RuntimeGraph

View File

@ -17,7 +17,7 @@ class GraphNode(BaseModel):
source_handle: Optional[str] = None source_handle: Optional[str] = None
"""current node source handle from the previous node result""" """current node source handle from the previous node result"""
is_continue_callback: Optional[Callable] = None run_condition_callback: Optional[Callable] = None
"""condition function check if the node can be executed""" """condition function check if the node can be executed"""
node_config: dict node_config: dict
@ -34,8 +34,8 @@ class GraphNode(BaseModel):
class Graph(BaseModel): class Graph(BaseModel):
graph: dict graph_config: dict
"""graph from workflow""" """graph config from workflow"""
graph_nodes: dict[str, GraphNode] = {} graph_nodes: dict[str, GraphNode] = {}
"""graph nodes""" """graph nodes"""
@ -46,14 +46,14 @@ class Graph(BaseModel):
def add_edge(self, edge_config: dict, def add_edge(self, edge_config: dict,
source_node_config: dict, source_node_config: dict,
target_node_config: dict, target_node_config: dict,
is_continue_callback: Optional[Callable] = None) -> None: run_condition_callback: Optional[Callable] = None) -> None:
""" """
Add edge to the graph Add edge to the graph
:param edge_config: edge config :param edge_config: edge config
:param source_node_config: source node config :param source_node_config: source node config
:param target_node_config: target node config :param target_node_config: target node config
:param is_continue_callback: condition callback :param run_condition_callback: condition callback
""" """
source_node_id = source_node_config.get('id') source_node_id = source_node_config.get('id')
if not source_node_id: if not source_node_id:
@ -77,17 +77,12 @@ class Graph(BaseModel):
source_node.add_child(target_node_id) source_node.add_child(target_node_id)
source_node.target_edge_config = edge_config source_node.target_edge_config = edge_config
source_handle = None
if edge_config.get('sourceHandle'):
source_handle = edge_config.get('sourceHandle')
if target_node_id not in self.graph_nodes: if target_node_id not in self.graph_nodes:
target_graph_node = GraphNode( target_graph_node = GraphNode(
id=target_node_id, id=target_node_id,
predecessor_node_id=source_node_id, predecessor_node_id=source_node_id,
node_config=target_node_config, node_config=target_node_config,
source_handle=source_handle, run_condition_callback=run_condition_callback,
is_continue_callback=is_continue_callback,
source_edge_config=edge_config, source_edge_config=edge_config,
) )
@ -95,8 +90,7 @@ class Graph(BaseModel):
else: else:
target_node = self.graph_nodes[target_node_id] target_node = self.graph_nodes[target_node_id]
target_node.predecessor_node_id = source_node_id target_node.predecessor_node_id = source_node_id
target_node.source_handle = source_handle target_node.run_condition_callback = run_condition_callback
target_node.is_continue_callback = is_continue_callback
target_node.source_edge_config = edge_config target_node.source_edge_config = edge_config
def add_graph_node(self, graph_node: GraphNode) -> None: def add_graph_node(self, graph_node: GraphNode) -> None:
@ -135,7 +129,7 @@ class Graph(BaseModel):
if not graph_node.children_node_ids: if not graph_node.children_node_ids:
return None return None
descendants_graph = Graph() descendants_graph = Graph(graph_config=self.graph_config)
descendants_graph.add_graph_node(graph_node) descendants_graph.add_graph_node(graph_node)
for child_node_id in graph_node.children_node_ids: for child_node_id in graph_node.children_node_ids: