This commit is contained in:
takatost 2024-06-29 15:44:52 +08:00
parent 1d8ecac093
commit 8375517ccd
10 changed files with 995 additions and 470 deletions

View File

@ -1,6 +1,8 @@
from enum import Enum from enum import Enum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, Field
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.node_entities import SystemVariable
@ -21,20 +23,23 @@ class ValueType(Enum):
FILE = "file" FILE = "file"
class VariablePool: class VariablePool(BaseModel):
def __init__(self, system_variables: dict[SystemVariable, Any], variables_mapping: dict[str, dict[int, VariableValue]] = Field(
user_inputs: dict) -> None: description='Variables mapping',
# system variables default={},
# for example: )
# {
# 'query': 'abc', user_inputs: dict = Field(
# 'files': [] description='User inputs',
# } )
self.variables_mapping = {}
self.user_inputs = user_inputs system_variables: dict[SystemVariable, Any] = Field(
self.system_variables = system_variables description='System variables',
for system_variable, value in system_variables.items(): )
def __post_init__(self):
for system_variable, value in self.system_variables.items():
self.append_variable('sys', [system_variable.value], value) self.append_variable('sys', [system_variable.value], value)
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None: def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:

View File

@ -1,21 +1,40 @@
from collections.abc import Callable from collections.abc import Callable
from typing import Optional from typing import Literal, Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
"""branch identify, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
"""conditions to run the node, required when type is condition"""
class GraphNode(BaseModel): class GraphNode(BaseModel):
id: str id: str
"""node id""" """node id"""
parent_id: Optional[str] = None
"""parent node id, e.g. iteration/loop"""
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
"""predecessor node id""" """predecessor node id"""
children_node_ids: list[str] = [] descendant_node_ids: list[str] = []
"""children node ids""" """descendant node ids"""
run_condition_callback: Optional[Callable] = None run_condition: Optional[RunCondition] = None
"""condition function check if the node can be executed""" """condition to run the node"""
run_condition_callback: Optional[Callable] = Field(None, exclude=True)
"""condition function check if the node can be executed, translated from run_conditions, not serialized"""
node_config: dict node_config: dict
"""original node config""" """original node config"""
@ -23,72 +42,96 @@ class GraphNode(BaseModel):
source_edge_config: Optional[dict] = None source_edge_config: Optional[dict] = None
"""original source edge config""" """original source edge config"""
target_edge_config: Optional[dict] = None sub_graph: Optional["Graph"] = None
"""original target edge config""" """sub graph of the node, e.g. iteration/loop sub graph"""
def add_child(self, node_id: str) -> None: def add_child(self, node_id: str) -> None:
self.children_node_ids.append(node_id) self.descendant_node_ids.append(node_id)
class Graph(BaseModel): class Graph(BaseModel):
graph_config: dict
"""graph config from workflow"""
graph_nodes: dict[str, GraphNode] = {} graph_nodes: dict[str, GraphNode] = {}
"""graph nodes""" """graph nodes"""
root_node: Optional[GraphNode] = None root_node: GraphNode
"""root node of the graph""" """root node of the graph"""
@classmethod
def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph":
"""
Init graph
:param root_node_config: root node config
:param run_condition: run condition when root node parent is iteration/loop
:return: graph
"""
root_node = GraphNode(
id=root_node_config.get('id'),
parent_id=root_node_config.get('parentId'),
node_config=root_node_config,
run_condition=run_condition
)
graph = cls(root_node=root_node)
# TODO parse run_condition to run_condition_callback
graph.add_graph_node(graph.root_node)
return graph
def add_edge(self, edge_config: dict, def add_edge(self, edge_config: dict,
source_node_config: dict, source_node_config: dict,
target_node_config: dict, target_node_config: dict,
run_condition_callback: Optional[Callable] = None) -> None: target_node_sub_graph: Optional["Graph"] = None) -> None:
""" """
Add edge to the graph Add edge to the graph
:param edge_config: edge config :param edge_config: edge config
:param source_node_config: source node config :param source_node_config: source node config
:param target_node_config: target node config :param target_node_config: target node config
:param run_condition_callback: condition callback :param target_node_sub_graph: sub graph
""" """
source_node_id = source_node_config.get('id') source_node_id = source_node_config.get('id')
if not source_node_id: if not source_node_id:
return return
if source_node_id not in self.graph_nodes:
return
target_node_id = target_node_config.get('id') target_node_id = target_node_config.get('id')
if not target_node_id: if not target_node_id:
return return
if source_node_id not in self.graph_nodes:
source_graph_node = GraphNode(
id=source_node_id,
node_config=source_node_config,
children_node_ids=[target_node_id],
target_edge_config=edge_config,
)
self.add_graph_node(source_graph_node)
else:
source_node = self.graph_nodes[source_node_id] source_node = self.graph_nodes[source_node_id]
source_node.add_child(target_node_id) source_node.add_child(target_node_id)
source_node.target_edge_config = edge_config
# if run_conditions:
# run_condition_callback = lambda: all()
if target_node_id not in self.graph_nodes: if target_node_id not in self.graph_nodes:
run_condition = None # todo
run_condition_callback = None # todo
target_graph_node = GraphNode( target_graph_node = GraphNode(
id=target_node_id, id=target_node_id,
parent_id=source_node_config.get('parentId'),
predecessor_node_id=source_node_id, predecessor_node_id=source_node_id,
node_config=target_node_config, node_config=target_node_config,
run_condition=run_condition,
run_condition_callback=run_condition_callback, run_condition_callback=run_condition_callback,
source_edge_config=edge_config, source_edge_config=edge_config,
sub_graph=target_node_sub_graph
) )
self.add_graph_node(target_graph_node) self.add_graph_node(target_graph_node)
else: else:
target_node = self.graph_nodes[target_node_id] target_node = self.graph_nodes[target_node_id]
target_node.predecessor_node_id = source_node_id target_node.predecessor_node_id = source_node_id
target_node.run_conditions = run_conditions
target_node.run_condition_callback = run_condition_callback target_node.run_condition_callback = run_condition_callback
target_node.source_edge_config = edge_config target_node.source_edge_config = edge_config
target_node.sub_graph = target_node_sub_graph
def add_graph_node(self, graph_node: GraphNode) -> None: def add_graph_node(self, graph_node: GraphNode) -> None:
""" """
@ -123,13 +166,13 @@ class Graph(BaseModel):
return None return None
graph_node = self.graph_nodes[node_id] graph_node = self.graph_nodes[node_id]
if not graph_node.children_node_ids: if not graph_node.descendant_node_ids:
return None return None
descendants_graph = Graph(graph_config=self.graph_config) descendants_graph = Graph()
descendants_graph.add_graph_node(graph_node) descendants_graph.add_graph_node(graph_node)
for child_node_id in graph_node.children_node_ids: for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id) self._add_descendants_graph_nodes(descendants_graph, child_node_id)
return descendants_graph return descendants_graph
@ -147,5 +190,5 @@ class Graph(BaseModel):
graph_node = self.graph_nodes[node_id] graph_node = self.graph_nodes[node_id]
descendants_graph.add_graph_node(graph_node) descendants_graph.add_graph_node(graph_node)
for child_node_id in graph_node.children_node_ids: for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id) self._add_descendants_graph_nodes(descendants_graph, child_node_id)

View File

@ -1,26 +1,12 @@
from typing import Literal, Optional from typing import Literal
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.utils.condition.entities import Condition
class IfElseNodeData(BaseNodeData): class IfElseNodeData(BaseNodeData):
""" """
Answer Node Data. Answer Node Data.
""" """
class Condition(BaseModel):
"""
Condition entity
"""
variable_selector: list[str]
comparison_operator: Literal[
# for string or array
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
# for number
"=", "", ">", "<", "", "", "null", "not null"
]
value: Optional[str] = None
logical_operator: Literal["and", "or"] = "and" logical_operator: Literal["and", "or"] = "and"
conditions: list[Condition] conditions: list[Condition]

View File

@ -1,10 +1,11 @@
from typing import Optional, cast from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.processor import ConditionAssertionError, ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -19,90 +20,42 @@ class IfElseNode(BaseNode):
:return: :return:
""" """
node_data = self.node_data node_data = self.node_data
node_data = cast(self._node_data_cls, node_data) node_data = cast(IfElseNodeData, node_data)
node_inputs = { node_inputs: dict[str, list] = {
"conditions": [] "conditions": []
} }
process_datas = { process_datas: dict[str, list] = {
"condition_results": [] "condition_results": []
} }
try: try:
logical_operator = node_data.logical_operator processor = ConditionProcessor()
input_conditions = [] compare_result, sub_condition_compare_results = processor.process(
for condition in node_data.conditions: variable_pool=variable_pool,
actual_value = variable_pool.get_variable_value( logical_operator=node_data.logical_operator,
variable_selector=condition.variable_selector conditions=node_data.conditions,
) )
expected_value = condition.value node_inputs["conditions"] = [{
"actual_value": result['actual_value'],
"expected_value": result['expected_value'],
"comparison_operator": result['comparison_operator'],
} for result in sub_condition_compare_results]
input_conditions.append({ process_datas["condition_results"] = sub_condition_compare_results
"actual_value": actual_value, except ConditionAssertionError as e:
"expected_value": expected_value, node_inputs["conditions"] = e.conditions
"comparison_operator": condition.comparison_operator process_datas["condition_results"] = e.sub_condition_compare_results
})
node_inputs["conditions"] = input_conditions
for input_condition in input_conditions:
actual_value = input_condition["actual_value"]
expected_value = input_condition["expected_value"]
comparison_operator = input_condition["comparison_operator"]
if comparison_operator == "contains":
compare_result = self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
compare_result = self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
compare_result = self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
compare_result = self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
compare_result = self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
compare_result = self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
compare_result = self._assert_empty(actual_value)
elif comparison_operator == "not empty":
compare_result = self._assert_not_empty(actual_value)
elif comparison_operator == "=":
compare_result = self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
compare_result = self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
compare_result = self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
compare_result = self._assert_null(actual_value)
elif comparison_operator == "not null":
compare_result = self._assert_not_null(actual_value)
else:
continue
process_datas["condition_results"].append({
**input_condition,
"result": compare_result
})
except Exception as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=node_inputs, inputs=node_inputs,
process_data=process_datas, process_data=process_datas,
error=str(e) error=str(e)
) )
except Exception as e:
if logical_operator == "and": raise e
compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]]
else:
compare_result = True in [condition["result"] for condition in process_datas["condition_results"]]
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -114,280 +67,6 @@ class IfElseNode(BaseNode):
} }
) )
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value not in actual_value:
return False
return True
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert not contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return True
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value in actual_value:
return False
return True
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert start with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.startswith(expected_value):
return False
return True
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert end with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.endswith(expected_value):
return False
return True
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value != expected_value:
return False
return True
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is not
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value == expected_value:
return False
return True
def _assert_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert empty
:param actual_value: actual value
:return:
"""
if not actual_value:
return True
return False
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert not empty
:param actual_value: actual value
:return:
"""
if actual_value:
return True
return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value != expected_value:
return False
return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert not equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value == expected_value:
return False
return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value <= expected_value:
return False
return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value >= expected_value:
return False
return True
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value < expected_value:
return False
return True
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value > expected_value:
return False
return True
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert null
:param actual_value: actual value
:return:
"""
if actual_value is None:
return True
return False
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert not null
:param actual_value: actual value
:return:
"""
if actual_value is not None:
return True
return False
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
""" """

View File

@ -0,0 +1,17 @@
from typing import Literal, Optional
from pydantic import BaseModel
class Condition(BaseModel):
"""
Condition entity
"""
variable_selector: list[str]
comparison_operator: Literal[
# for string or array
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
# for number
"=", "", ">", "<", "", "", "null", "not null"
]
value: Optional[str] = None

View File

@ -0,0 +1,22 @@
from typing import Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
from core.workflow.graph import GraphNode
def source_handle_condition_func(workflow_runtime_state: WorkflowRuntimeState,
graph_node: GraphNode,
# TODO cycle_state optional
predecessor_node_run_result: Optional[NodeRunResult] = None) -> bool:
if not graph_node.source_edge_config:
return False
if not graph_node.source_edge_config.get('sourceHandle'):
return True
source_handle = predecessor_node_run_result.edge_source_handle \
if predecessor_node_run_result else None
return (source_handle is not None
and graph_node.source_edge_config.get('sourceHandle') == source_handle)

View File

@ -0,0 +1,369 @@
from typing import Literal, Optional
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils.condition.entities import Condition
class ConditionProcessor:
def process(self, variable_pool: VariablePool,
logical_operator: Literal["and", "or"],
conditions: list[Condition]) -> tuple[bool, list[dict]]:
"""
Process conditions
:param variable_pool: variable pool
:param logical_operator: logical operator
:param conditions: conditions
"""
input_conditions = []
sub_condition_compare_results = []
try:
for condition in conditions:
actual_value = variable_pool.get_variable_value(
variable_selector=condition.variable_selector
)
expected_value = condition.value
input_conditions.append({
"actual_value": actual_value,
"expected_value": expected_value,
"comparison_operator": condition.comparison_operator
})
for input_condition in input_conditions:
actual_value = input_condition["actual_value"]
expected_value = input_condition["expected_value"]
comparison_operator = input_condition["comparison_operator"]
if comparison_operator == "contains":
compare_result = self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
compare_result = self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
compare_result = self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
compare_result = self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
compare_result = self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
compare_result = self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
compare_result = self._assert_empty(actual_value)
elif comparison_operator == "not empty":
compare_result = self._assert_not_empty(actual_value)
elif comparison_operator == "=":
compare_result = self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
compare_result = self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
compare_result = self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
compare_result = self._assert_null(actual_value)
elif comparison_operator == "not null":
compare_result = self._assert_not_null(actual_value)
else:
continue
sub_condition_compare_results.append({
**input_condition,
"result": compare_result
})
except Exception as e:
raise ConditionAssertionError(str(e), input_conditions, sub_condition_compare_results)
if logical_operator == "and":
compare_result = False not in [condition["result"] for condition in sub_condition_compare_results]
else:
compare_result = True in [condition["result"] for condition in sub_condition_compare_results]
return compare_result, sub_condition_compare_results
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value not in actual_value:
return False
return True
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert not contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return True
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value in actual_value:
return False
return True
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert start with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.startswith(expected_value):
return False
return True
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert end with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.endswith(expected_value):
return False
return True
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value != expected_value:
return False
return True
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is not
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value == expected_value:
return False
return True
def _assert_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert empty
:param actual_value: actual value
:return:
"""
if not actual_value:
return True
return False
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert not empty
:param actual_value: actual value
:return:
"""
if actual_value:
return True
return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value != expected_value:
return False
return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert not equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value == expected_value:
return False
return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value <= expected_value:
return False
return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value >= expected_value:
return False
return True
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value < expected_value:
return False
return True
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value > expected_value:
return False
return True
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert null
:param actual_value: actual value
:return:
"""
if actual_value is None:
return True
return False
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert not null
:param actual_value: actual value
:return:
"""
if actual_value is not None:
return True
return False
class ConditionAssertionError(Exception):
def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None:
self.message = message
self.conditions = conditions
self.sub_condition_compare_results = sub_condition_compare_results
super().__init__(self.message)

View File

@ -1,5 +1,4 @@
import logging import logging
import threading
import time import time
from typing import Any, Optional, cast from typing import Any, Optional, cast
@ -129,18 +128,172 @@ class WorkflowEngineManager:
callbacks=callbacks callbacks=callbacks
) )
try:
# run workflow # run workflow
self._run_workflow( self._run_workflow(
graph_config=graph_config, graph_config=graph_config,
workflow_runtime_state=workflow_runtime_state, workflow_runtime_state=workflow_runtime_state,
callbacks=callbacks, callbacks=callbacks,
) )
except WorkflowRunFailedError as e:
self._workflow_run_failed(
error=e.error,
callbacks=callbacks
)
except Exception as e:
self._workflow_run_failed(
error=str(e),
callbacks=callbacks
)
# workflow run success # workflow run success
self._workflow_run_success( self._workflow_run_success(
callbacks=callbacks callbacks=callbacks
) )
def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]:
"""
Initialize graph
:param graph_config: graph config
:param root_node_id: root node id if needed
:return: graph
"""
# edge configs
edge_configs = graph_config.get('edges')
if not edge_configs:
return None
edge_configs = cast(list, edge_configs)
# reorganize edges mapping
source_edges_mapping: dict[str, list[dict]] = {}
target_edge_ids = set()
for edge_config in edge_configs:
source_node_id = edge_config.get('source')
if not source_node_id:
continue
if source_node_id not in source_edges_mapping:
source_edges_mapping[source_node_id] = []
source_edges_mapping[source_node_id].append(edge_config)
target_node_id = edge_config.get('target')
if target_node_id:
target_edge_ids.add(target_node_id)
# node configs
node_configs = graph_config.get('nodes')
if not node_configs:
return None
node_configs = cast(list, node_configs)
# fetch nodes that have no predecessor node
root_node_configs = []
nodes_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get('id')
if not node_id:
continue
if node_id not in target_edge_ids:
root_node_configs.append(node_config)
nodes_mapping[node_id] = node_config
# fetch root node
if root_node_id:
root_node_config = next((node_config for node_config in root_node_configs
if node_config.get('id') == root_node_id), None)
else:
# if no root node id, use the START type node as root node
root_node_config = next((node_config for node_config in root_node_configs
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
if not root_node_config:
return None
# init graph
graph = Graph.init(
root_node_config=root_node_config
)
# add edge from root node
self._recursively_add_edges(
graph=graph,
source_node_config=root_node_config,
edges_mapping=source_edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
return graph
def _recursively_add_edges(self, graph: Graph,
source_node_config: dict,
edges_mapping: dict,
nodes_mapping: dict,
root_node_configs: list[dict]) -> None:
"""
Add edges
:param source_node_config: source node config
:param edges_mapping: edges mapping
:param nodes_mapping: nodes mapping
:param root_node_configs: root node configs
"""
source_node_id = source_node_config.get('id')
if not source_node_id:
return
for edge_config in edges_mapping.get(source_node_id, []):
target_node_id = edge_config.get('target')
if not target_node_id:
continue
target_node_config = nodes_mapping.get(target_node_id)
if not target_node_config:
continue
sub_graph: Optional[Graph] = None
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
if target_node_type and target_node_type in [IterationNode.node_type, NodeType.LOOP]:
# find iteration/loop sub nodes that have no predecessor node
for root_node_config in root_node_configs:
if root_node_config.get('parentId') == target_node_id:
# create sub graph
sub_graph = Graph.init(
root_node_config=root_node_config
)
self._recursively_add_edges(
graph=sub_graph,
source_node_config=root_node_config,
edges_mapping=edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
break
# add edge
graph.add_edge(
edge_config=edge_config,
source_node_config=source_node_config,
target_node_config=target_node_config,
target_node_sub_graph=sub_graph,
)
# recursively add edges
self._recursively_add_edges(
graph=graph,
source_node_config=target_node_config,
edges_mapping=edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
def _run_workflow(self, graph_config: dict, def _run_workflow(self, graph_config: dict,
workflow_runtime_state: WorkflowRuntimeState, workflow_runtime_state: WorkflowRuntimeState,
callbacks: list[BaseWorkflowCallback], callbacks: list[BaseWorkflowCallback],
@ -157,10 +310,15 @@ class WorkflowEngineManager:
""" """
try: try:
# init graph # init graph
graph = Graph( graph = self._init_graph(
graph_config=graph_config graph_config=graph_config
) )
if not graph:
raise WorkflowRunFailedError(
error='Start node not found in workflow graph.'
)
predecessor_node: Optional[BaseNode] = None predecessor_node: Optional[BaseNode] = None
current_iteration_node: Optional[BaseIterationNode] = None current_iteration_node: Optional[BaseIterationNode] = None
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
@ -231,11 +389,11 @@ class WorkflowEngineManager:
# max steps reached # max steps reached
if workflow_run_state.workflow_node_steps > max_execution_steps: if workflow_run_state.workflow_node_steps > max_execution_steps:
raise ValueError('Max steps {} reached.'.format(max_execution_steps)) raise WorkflowRunFailedError('Max steps {} reached.'.format(max_execution_steps))
# or max execution time reached # or max execution time reached
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time):
raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) raise WorkflowRunFailedError('Max execution time {}s reached.'.format(max_execution_time))
if len(next_nodes) == 1: if len(next_nodes) == 1:
next_node = next_nodes[0] next_node = next_nodes[0]
@ -256,63 +414,59 @@ class WorkflowEngineManager:
else: else:
result_dict = {} result_dict = {}
# new thread # # new thread
worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={ # worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={
'flask_app': current_app._get_current_object(), # 'flask_app': current_app._get_current_object(),
'graph': graph, # 'graph': graph,
'workflow_run_state': workflow_run_state, # 'workflow_run_state': workflow_run_state,
'predecessor_node': predecessor_node, # 'predecessor_node': predecessor_node,
'next_nodes': next_nodes, # 'next_nodes': next_nodes,
'callbacks': callbacks, # 'callbacks': callbacks,
'result': result_dict # 'result': result_dict
}) # })
#
worker_thread.start() # worker_thread.start()
worker_thread.join() # worker_thread.join()
if not workflow_run_state.workflow_node_runs: if not workflow_run_state.workflow_node_runs:
self._workflow_run_failed( raise WorkflowRunFailedError(
error='Start node not found in workflow graph.', error='Start node not found in workflow graph.'
callbacks=callbacks
) )
return
except GenerateTaskStoppedException as e: except GenerateTaskStoppedException as e:
return return
except Exception as e: except Exception as e:
self._workflow_run_failed( raise WorkflowRunFailedError(
error=str(e), error=str(e)
callbacks=callbacks
)
return
def _async_run_nodes(self, flask_app: Flask,
graph: dict,
workflow_run_state: WorkflowRunState,
predecessor_node: Optional[BaseNode],
next_nodes: list[BaseNode],
callbacks: list[BaseWorkflowCallback],
result: dict):
with flask_app.app_context():
try:
for next_node in next_nodes:
# TODO run sub workflows
# run node
is_continue = self._run_node(
graph=graph,
workflow_run_state=workflow_run_state,
predecessor_node=predecessor_node,
current_node=next_node,
callbacks=callbacks
) )
if not is_continue: # def _async_run_nodes(self, flask_app: Flask,
break # graph: dict,
# workflow_run_state: WorkflowRunState,
predecessor_node = next_node # predecessor_node: Optional[BaseNode],
except Exception as e: # next_nodes: list[BaseNode],
logger.exception("Unknown Error when generating") # callbacks: list[BaseWorkflowCallback],
finally: # result: dict):
db.session.remove() # with flask_app.app_context():
# try:
# for next_node in next_nodes:
# # TODO run sub workflows
# # run node
# is_continue = self._run_node(
# graph=graph,
# workflow_run_state=workflow_run_state,
# predecessor_node=predecessor_node,
# current_node=next_node,
# callbacks=callbacks
# )
#
# if not is_continue:
# break
#
# predecessor_node = next_node
# except Exception as e:
# logger.exception("Unknown Error when generating")
# finally:
# db.session.remove()
def _run_node(self, graph: dict, def _run_node(self, graph: dict,
workflow_run_state: WorkflowRunState, workflow_run_state: WorkflowRunState,
@ -584,14 +738,25 @@ class WorkflowEngineManager:
workflow_call_depth=0 workflow_call_depth=0
) )
try:
# run workflow # run workflow
self._run_workflow( self._run_workflow(
graph=workflow.graph, graph_config=workflow.graph,
workflow_run_state=workflow_run_state, workflow_runtime_state=workflow_runtime_state,
callbacks=callbacks, callbacks=callbacks,
start_node=node_id, start_node=node_id,
end_node=end_node_id end_node=end_node_id
) )
except WorkflowRunFailedError as e:
self._workflow_run_failed(
error=e.error,
callbacks=callbacks
)
except Exception as e:
self._workflow_run_failed(
error=str(e),
callbacks=callbacks
)
# workflow run success # workflow run success
self._workflow_run_success( self._workflow_run_success(
@ -1072,3 +1237,8 @@ class WorkflowEngineManager:
variable_key_list=variable_key_list, variable_key_list=variable_key_list,
value=value value=value
) )
class WorkflowRunFailedError(Exception):
def __init__(self, error: str):
self.error = error

View File

@ -0,0 +1,234 @@
from core.workflow.graph import Graph
from core.workflow.workflow_engine_manager import WorkflowEngineManager
def test__init_graph():
graph_config = {
"edges": [
{
"id": "llm-source-answer-target",
"source": "llm",
"target": "answer",
},
{
"id": "1717222650545-source-1719481290322-target",
"source": "1717222650545",
"target": "1719481290322",
},
{
"id": "1719481290322-1-llm-target",
"source": "1719481290322",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "1719481290322-2-1719481315734-target",
"source": "1719481290322",
"sourceHandle": "2",
"target": "1719481315734",
},
{
"id": "1719481315734-source-1719481326339-target",
"source": "1719481315734",
"target": "1719481326339",
}
],
"nodes": [
{
"data": {
"desc": "",
"title": "Start",
"type": "start",
"variables": [
{
"label": "name",
"max_length": 48,
"options": [],
"required": False,
"type": "text-input",
"variable": "name"
}
]
},
"id": "1717222650545",
"position": {
"x": -147.65487258270954,
"y": 263.5326708413438
},
},
{
"data": {
"context": {
"enabled": False,
"variable_selector": []
},
"desc": "",
"memory": {
"query_prompt_template": "{{#sys.query#}}",
"role_prefix": {
"assistant": "",
"user": ""
},
"window": {
"enabled": False,
"size": 10
}
},
"model": {
"completion_params": {
"temperature": 0
},
"mode": "chat",
"name": "anthropic.claude-3-sonnet-20240229-v1:0",
"provider": "bedrock"
},
"prompt_config": {
"jinja2_variables": [
{
"value_selector": [
"sys",
"query"
],
"variable": "query"
}
]
},
"prompt_template": [
{
"edition_type": "basic",
"id": "8b02d178-3aa0-4dbd-82bf-8b6a40658300",
"jinja2_text": "",
"role": "system",
"text": "yep"
}
],
"title": "LLM",
"type": "llm",
"variables": [],
"vision": {
"configs": {
"detail": "low"
},
"enabled": True
}
},
"id": "llm",
"position": {
"x": 654.0331237272932,
"y": 263.5326708413438
},
},
{
"data": {
"answer": "123{{#llm.text#}}",
"desc": "",
"title": "Answer",
"type": "answer",
"variables": []
},
"id": "answer",
"position": {
"x": 958.1129142362784,
"y": 263.5326708413438
},
},
{
"data": {
"classes": [
{
"id": "1",
"name": "happy"
},
{
"id": "2",
"name": "sad"
}
],
"desc": "",
"instructions": "",
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai"
},
"query_variable_selector": [
"1717222650545",
"sys.query"
],
"title": "Question Classifier",
"topics": [],
"type": "question-classifier"
},
"id": "1719481290322",
"position": {
"x": 165.25154615277052,
"y": 263.5326708413438
}
},
{
"data": {
"authorization": {
"config": None,
"type": "no-auth"
},
"body": {
"data": "",
"type": "none"
},
"desc": "",
"headers": "",
"method": "get",
"params": "",
"timeout": {
"max_connect_timeout": 0,
"max_read_timeout": 0,
"max_write_timeout": 0
},
"title": "HTTP Request",
"type": "http-request",
"url": "https://baidu.com",
"variables": []
},
"height": 88,
"id": "1719481315734",
"position": {
"x": 654.0331237272932,
"y": 474.1180064703089
}
},
{
"data": {
"answer": "{{#1719481315734.status_code#}}",
"desc": "",
"title": "Answer 2",
"type": "answer",
"variables": []
},
"height": 105,
"id": "1719481326339",
"position": {
"x": 958.1129142362784,
"y": 474.1180064703089
},
}
],
}
workflow_engine_manager = WorkflowEngineManager()
graph = workflow_engine_manager._init_graph(
graph_config=graph_config
)
assert graph.root_node.id == "1717222650545"
assert graph.root_node.source_edge_config is None
assert graph.root_node.target_edge_config is not None
assert graph.root_node.descendant_node_ids == ["1719481290322"]
assert graph.graph_nodes.get("1719481290322") is not None
assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2
assert graph.graph_nodes.get("llm").run_condition_callback is not None
assert graph.graph_nodes.get("1719481315734").run_condition_callback is not None