add runtime state of graph

This commit is contained in:
takatost 2024-06-25 17:43:13 +08:00
parent fe27c97fd9
commit 216910a4a1
2 changed files with 50 additions and 19 deletions

View File

@ -1,3 +1,4 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
@ -8,11 +9,12 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph import Graph, GraphNode
from core.workflow.nodes.base_node import BaseNode, UserFrom
from models.workflow import WorkflowType
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
class RuntimeNode(BaseModel):
class Status(Enum):
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
@ -30,16 +32,16 @@ class RuntimeNode(BaseModel):
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.RUNNING
status: Status = Status.PENDING
"""node status"""
start_at: float
start_at: Optional[datetime] = None
"""start time"""
paused_at: Optional[float] = None
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[float] = None
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
@ -48,6 +50,39 @@ class RuntimeNode(BaseModel):
paused_by: Optional[str] = None
"""paused by"""
predecessor_runtime_node_id: Optional[str] = None
"""predecessor runtime node id"""
class RuntimeGraph(BaseModel):
runtime_nodes: dict[str, RuntimeNode] = {}
"""runtime nodes"""
def add_runtime_node(self, runtime_node: RuntimeNode) -> None:
self.runtime_nodes[runtime_node.id] = runtime_node
def add_link(self, source_runtime_node_id: str, target_runtime_node_id: str) -> None:
if source_runtime_node_id in self.runtime_nodes and target_runtime_node_id in self.runtime_nodes:
target_runtime_node = self.runtime_nodes[target_runtime_node_id]
target_runtime_node.predecessor_runtime_node_id = source_runtime_node_id
def runtime_node_finished(self, runtime_node_id: str, node_run_result: NodeRunResult) -> None:
if runtime_node_id in self.runtime_nodes:
runtime_node = self.runtime_nodes[runtime_node_id]
runtime_node.status = RuntimeNode.Status.SUCCESS \
if node_run_result.status == WorkflowNodeExecutionStatus.RUNNING \
else RuntimeNode.Status.FAILED
runtime_node.node_run_result = node_run_result
runtime_node.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
runtime_node.failed_reason = node_run_result.error
def runtime_node_paused(self, runtime_node_id: str, paused_by: Optional[str] = None) -> None:
if runtime_node_id in self.runtime_nodes:
runtime_node = self.runtime_nodes[runtime_node_id]
runtime_node.status = RuntimeNode.Status.PAUSED
runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
runtime_node.paused_by = paused_by
class WorkflowRuntimeState(BaseModel):
tenant_id: str
@ -64,3 +99,5 @@ class WorkflowRuntimeState(BaseModel):
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph

View File

@ -17,7 +17,7 @@ class GraphNode(BaseModel):
source_handle: Optional[str] = None
"""current node source handle from the previous node result"""
is_continue_callback: Optional[Callable] = None
run_condition_callback: Optional[Callable] = None
"""condition function check if the node can be executed"""
node_config: dict
@ -34,8 +34,8 @@ class GraphNode(BaseModel):
class Graph(BaseModel):
graph: dict
"""graph from workflow"""
graph_config: dict
"""graph config from workflow"""
graph_nodes: dict[str, GraphNode] = {}
"""graph nodes"""
@ -46,14 +46,14 @@ class Graph(BaseModel):
def add_edge(self, edge_config: dict,
source_node_config: dict,
target_node_config: dict,
is_continue_callback: Optional[Callable] = None) -> None:
run_condition_callback: Optional[Callable] = None) -> None:
"""
Add edge to the graph
:param edge_config: edge config
:param source_node_config: source node config
:param target_node_config: target node config
:param is_continue_callback: condition callback
:param run_condition_callback: condition callback
"""
source_node_id = source_node_config.get('id')
if not source_node_id:
@ -77,17 +77,12 @@ class Graph(BaseModel):
source_node.add_child(target_node_id)
source_node.target_edge_config = edge_config
source_handle = None
if edge_config.get('sourceHandle'):
source_handle = edge_config.get('sourceHandle')
if target_node_id not in self.graph_nodes:
target_graph_node = GraphNode(
id=target_node_id,
predecessor_node_id=source_node_id,
node_config=target_node_config,
source_handle=source_handle,
is_continue_callback=is_continue_callback,
run_condition_callback=run_condition_callback,
source_edge_config=edge_config,
)
@ -95,8 +90,7 @@ class Graph(BaseModel):
else:
target_node = self.graph_nodes[target_node_id]
target_node.predecessor_node_id = source_node_id
target_node.source_handle = source_handle
target_node.is_continue_callback = is_continue_callback
target_node.run_condition_callback = run_condition_callback
target_node.source_edge_config = edge_config
def add_graph_node(self, graph_node: GraphNode) -> None:
@ -135,7 +129,7 @@ class Graph(BaseModel):
if not graph_node.children_node_ids:
return None
descendants_graph = Graph()
descendants_graph = Graph(graph_config=self.graph_config)
descendants_graph.add_graph_node(graph_node)
for child_node_id in graph_node.children_node_ids: