This commit is contained in:
takatost 2024-06-29 15:44:52 +08:00
parent 1d8ecac093
commit 8375517ccd
10 changed files with 995 additions and 470 deletions

View File

@ -1,6 +1,8 @@
from enum import Enum
from typing import Any, Optional, Union
from pydantic import BaseModel, Field
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
@ -21,20 +23,23 @@ class ValueType(Enum):
FILE = "file"
class VariablePool:
class VariablePool(BaseModel):
def __init__(self, system_variables: dict[SystemVariable, Any],
user_inputs: dict) -> None:
# system variables
# for example:
# {
# 'query': 'abc',
# 'files': []
# }
self.variables_mapping = {}
self.user_inputs = user_inputs
self.system_variables = system_variables
for system_variable, value in system_variables.items():
variables_mapping: dict[str, dict[int, VariableValue]] = Field(
description='Variables mapping',
default={},
)
user_inputs: dict = Field(
description='User inputs',
)
system_variables: dict[SystemVariable, Any] = Field(
description='System variables',
)
def __post_init__(self):
for system_variable, value in self.system_variables.items():
self.append_variable('sys', [system_variable.value], value)
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:

View File

@ -1,21 +1,40 @@
from collections.abc import Callable
from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
"""branch identify, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
"""conditions to run the node, required when type is condition"""
class GraphNode(BaseModel):
id: str
"""node id"""
parent_id: Optional[str] = None
"""parent node id, e.g. iteration/loop"""
predecessor_node_id: Optional[str] = None
"""predecessor node id"""
children_node_ids: list[str] = []
"""children node ids"""
descendant_node_ids: list[str] = []
"""descendant node ids"""
run_condition_callback: Optional[Callable] = None
"""condition function check if the node can be executed"""
run_condition: Optional[RunCondition] = None
"""condition to run the node"""
run_condition_callback: Optional[Callable] = Field(None, exclude=True)
"""condition function check if the node can be executed, translated from run_conditions, not serialized"""
node_config: dict
"""original node config"""
@ -23,72 +42,96 @@ class GraphNode(BaseModel):
source_edge_config: Optional[dict] = None
"""original source edge config"""
target_edge_config: Optional[dict] = None
"""original target edge config"""
sub_graph: Optional["Graph"] = None
"""sub graph of the node, e.g. iteration/loop sub graph"""
def add_child(self, node_id: str) -> None:
self.children_node_ids.append(node_id)
self.descendant_node_ids.append(node_id)
class Graph(BaseModel):
graph_config: dict
"""graph config from workflow"""
graph_nodes: dict[str, GraphNode] = {}
"""graph nodes"""
root_node: Optional[GraphNode] = None
root_node: GraphNode
"""root node of the graph"""
@classmethod
def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph":
"""
Init graph
:param root_node_config: root node config
:param run_condition: run condition when root node parent is iteration/loop
:return: graph
"""
root_node = GraphNode(
id=root_node_config.get('id'),
parent_id=root_node_config.get('parentId'),
node_config=root_node_config,
run_condition=run_condition
)
graph = cls(root_node=root_node)
# TODO parse run_condition to run_condition_callback
graph.add_graph_node(graph.root_node)
return graph
def add_edge(self, edge_config: dict,
source_node_config: dict,
target_node_config: dict,
run_condition_callback: Optional[Callable] = None) -> None:
target_node_sub_graph: Optional["Graph"] = None) -> None:
"""
Add edge to the graph
:param edge_config: edge config
:param source_node_config: source node config
:param target_node_config: target node config
:param run_condition_callback: condition callback
:param target_node_sub_graph: sub graph
"""
source_node_id = source_node_config.get('id')
if not source_node_id:
return
if source_node_id not in self.graph_nodes:
return
target_node_id = target_node_config.get('id')
if not target_node_id:
return
if source_node_id not in self.graph_nodes:
source_graph_node = GraphNode(
id=source_node_id,
node_config=source_node_config,
children_node_ids=[target_node_id],
target_edge_config=edge_config,
)
source_node = self.graph_nodes[source_node_id]
source_node.add_child(target_node_id)
# if run_conditions:
# run_condition_callback = lambda: all()
self.add_graph_node(source_graph_node)
else:
source_node = self.graph_nodes[source_node_id]
source_node.add_child(target_node_id)
source_node.target_edge_config = edge_config
if target_node_id not in self.graph_nodes:
run_condition = None # todo
run_condition_callback = None # todo
target_graph_node = GraphNode(
id=target_node_id,
parent_id=source_node_config.get('parentId'),
predecessor_node_id=source_node_id,
node_config=target_node_config,
run_condition=run_condition,
run_condition_callback=run_condition_callback,
source_edge_config=edge_config,
sub_graph=target_node_sub_graph
)
self.add_graph_node(target_graph_node)
else:
target_node = self.graph_nodes[target_node_id]
target_node.predecessor_node_id = source_node_id
target_node.run_conditions = run_conditions
target_node.run_condition_callback = run_condition_callback
target_node.source_edge_config = edge_config
target_node.sub_graph = target_node_sub_graph
def add_graph_node(self, graph_node: GraphNode) -> None:
"""
@ -123,13 +166,13 @@ class Graph(BaseModel):
return None
graph_node = self.graph_nodes[node_id]
if not graph_node.children_node_ids:
if not graph_node.descendant_node_ids:
return None
descendants_graph = Graph(graph_config=self.graph_config)
descendants_graph = Graph()
descendants_graph.add_graph_node(graph_node)
for child_node_id in graph_node.children_node_ids:
for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
return descendants_graph
@ -147,5 +190,5 @@ class Graph(BaseModel):
graph_node = self.graph_nodes[node_id]
descendants_graph.add_graph_node(graph_node)
for child_node_id in graph_node.children_node_ids:
for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id)

View File

@ -1,26 +1,12 @@
from typing import Literal, Optional
from pydantic import BaseModel
from typing import Literal
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.utils.condition.entities import Condition
class IfElseNodeData(BaseNodeData):
"""
Answer Node Data.
"""
class Condition(BaseModel):
"""
Condition entity
"""
variable_selector: list[str]
comparison_operator: Literal[
# for string or array
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
# for number
"=", "", ">", "<", "", "", "null", "not null"
]
value: Optional[str] = None
logical_operator: Literal["and", "or"] = "and"
conditions: list[Condition]

View File

@ -1,10 +1,11 @@
from typing import Optional, cast
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.processor import ConditionAssertionError, ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus
@ -19,90 +20,42 @@ class IfElseNode(BaseNode):
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
node_data = cast(IfElseNodeData, node_data)
node_inputs = {
node_inputs: dict[str, list] = {
"conditions": []
}
process_datas = {
process_datas: dict[str, list] = {
"condition_results": []
}
try:
logical_operator = node_data.logical_operator
input_conditions = []
for condition in node_data.conditions:
actual_value = variable_pool.get_variable_value(
variable_selector=condition.variable_selector
)
processor = ConditionProcessor()
compare_result, sub_condition_compare_results = processor.process(
variable_pool=variable_pool,
logical_operator=node_data.logical_operator,
conditions=node_data.conditions,
)
expected_value = condition.value
node_inputs["conditions"] = [{
"actual_value": result['actual_value'],
"expected_value": result['expected_value'],
"comparison_operator": result['comparison_operator'],
} for result in sub_condition_compare_results]
input_conditions.append({
"actual_value": actual_value,
"expected_value": expected_value,
"comparison_operator": condition.comparison_operator
})
node_inputs["conditions"] = input_conditions
for input_condition in input_conditions:
actual_value = input_condition["actual_value"]
expected_value = input_condition["expected_value"]
comparison_operator = input_condition["comparison_operator"]
if comparison_operator == "contains":
compare_result = self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
compare_result = self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
compare_result = self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
compare_result = self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
compare_result = self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
compare_result = self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
compare_result = self._assert_empty(actual_value)
elif comparison_operator == "not empty":
compare_result = self._assert_not_empty(actual_value)
elif comparison_operator == "=":
compare_result = self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
compare_result = self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
compare_result = self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
compare_result = self._assert_null(actual_value)
elif comparison_operator == "not null":
compare_result = self._assert_not_null(actual_value)
else:
continue
process_datas["condition_results"].append({
**input_condition,
"result": compare_result
})
except Exception as e:
process_datas["condition_results"] = sub_condition_compare_results
except ConditionAssertionError as e:
node_inputs["conditions"] = e.conditions
process_datas["condition_results"] = e.sub_condition_compare_results
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=node_inputs,
process_data=process_datas,
error=str(e)
)
if logical_operator == "and":
compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]]
else:
compare_result = True in [condition["result"] for condition in process_datas["condition_results"]]
except Exception as e:
raise e
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -114,280 +67,6 @@ class IfElseNode(BaseNode):
}
)
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value not in actual_value:
return False
return True
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert not contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return True
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value in actual_value:
return False
return True
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert start with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.startswith(expected_value):
return False
return True
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert end with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.endswith(expected_value):
return False
return True
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value != expected_value:
return False
return True
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is not
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value == expected_value:
return False
return True
def _assert_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert empty
:param actual_value: actual value
:return:
"""
if not actual_value:
return True
return False
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert not empty
:param actual_value: actual value
:return:
"""
if actual_value:
return True
return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value != expected_value:
return False
return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert not equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value == expected_value:
return False
return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value <= expected_value:
return False
return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value >= expected_value:
return False
return True
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value < expected_value:
return False
return True
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value > expected_value:
return False
return True
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert null
:param actual_value: actual value
:return:
"""
if actual_value is None:
return True
return False
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert not null
:param actual_value: actual value
:return:
"""
if actual_value is not None:
return True
return False
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""

View File

@ -0,0 +1,17 @@
from typing import Literal, Optional
from pydantic import BaseModel
class Condition(BaseModel):
"""
Condition entity
"""
variable_selector: list[str]
comparison_operator: Literal[
# for string or array
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
# for number
"=", "", ">", "<", "", "", "null", "not null"
]
value: Optional[str] = None

View File

@ -0,0 +1,22 @@
from typing import Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
from core.workflow.graph import GraphNode
def source_handle_condition_func(workflow_runtime_state: WorkflowRuntimeState,
graph_node: GraphNode,
# TODO cycle_state optional
predecessor_node_run_result: Optional[NodeRunResult] = None) -> bool:
if not graph_node.source_edge_config:
return False
if not graph_node.source_edge_config.get('sourceHandle'):
return True
source_handle = predecessor_node_run_result.edge_source_handle \
if predecessor_node_run_result else None
return (source_handle is not None
and graph_node.source_edge_config.get('sourceHandle') == source_handle)

View File

@ -0,0 +1,369 @@
from typing import Literal, Optional
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils.condition.entities import Condition
class ConditionProcessor:
def process(self, variable_pool: VariablePool,
logical_operator: Literal["and", "or"],
conditions: list[Condition]) -> tuple[bool, list[dict]]:
"""
Process conditions
:param variable_pool: variable pool
:param logical_operator: logical operator
:param conditions: conditions
"""
input_conditions = []
sub_condition_compare_results = []
try:
for condition in conditions:
actual_value = variable_pool.get_variable_value(
variable_selector=condition.variable_selector
)
expected_value = condition.value
input_conditions.append({
"actual_value": actual_value,
"expected_value": expected_value,
"comparison_operator": condition.comparison_operator
})
for input_condition in input_conditions:
actual_value = input_condition["actual_value"]
expected_value = input_condition["expected_value"]
comparison_operator = input_condition["comparison_operator"]
if comparison_operator == "contains":
compare_result = self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
compare_result = self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
compare_result = self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
compare_result = self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
compare_result = self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
compare_result = self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
compare_result = self._assert_empty(actual_value)
elif comparison_operator == "not empty":
compare_result = self._assert_not_empty(actual_value)
elif comparison_operator == "=":
compare_result = self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
compare_result = self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
compare_result = self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
compare_result = self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
compare_result = self._assert_null(actual_value)
elif comparison_operator == "not null":
compare_result = self._assert_not_null(actual_value)
else:
continue
sub_condition_compare_results.append({
**input_condition,
"result": compare_result
})
except Exception as e:
raise ConditionAssertionError(str(e), input_conditions, sub_condition_compare_results)
if logical_operator == "and":
compare_result = False not in [condition["result"] for condition in sub_condition_compare_results]
else:
compare_result = True in [condition["result"] for condition in sub_condition_compare_results]
return compare_result, sub_condition_compare_results
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value not in actual_value:
return False
return True
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert not contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return True
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value in actual_value:
return False
return True
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert start with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.startswith(expected_value):
return False
return True
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert end with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.endswith(expected_value):
return False
return True
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value != expected_value:
return False
return True
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is not
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value == expected_value:
return False
return True
def _assert_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert empty
:param actual_value: actual value
:return:
"""
if not actual_value:
return True
return False
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert not empty
:param actual_value: actual value
:return:
"""
if actual_value:
return True
return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value != expected_value:
return False
return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert not equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value == expected_value:
return False
return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value <= expected_value:
return False
return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value >= expected_value:
return False
return True
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value < expected_value:
return False
return True
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value > expected_value:
return False
return True
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert null
:param actual_value: actual value
:return:
"""
if actual_value is None:
return True
return False
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert not null
:param actual_value: actual value
:return:
"""
if actual_value is not None:
return True
return False
class ConditionAssertionError(Exception):
def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None:
self.message = message
self.conditions = conditions
self.sub_condition_compare_results = sub_condition_compare_results
super().__init__(self.message)

View File

@ -1,5 +1,4 @@
import logging
import threading
import time
from typing import Any, Optional, cast
@ -129,18 +128,172 @@ class WorkflowEngineManager:
callbacks=callbacks
)
# run workflow
self._run_workflow(
graph_config=graph_config,
workflow_runtime_state=workflow_runtime_state,
callbacks=callbacks,
)
try:
# run workflow
self._run_workflow(
graph_config=graph_config,
workflow_runtime_state=workflow_runtime_state,
callbacks=callbacks,
)
except WorkflowRunFailedError as e:
self._workflow_run_failed(
error=e.error,
callbacks=callbacks
)
except Exception as e:
self._workflow_run_failed(
error=str(e),
callbacks=callbacks
)
# workflow run success
self._workflow_run_success(
callbacks=callbacks
)
def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]:
"""
Initialize graph
:param graph_config: graph config
:param root_node_id: root node id if needed
:return: graph
"""
# edge configs
edge_configs = graph_config.get('edges')
if not edge_configs:
return None
edge_configs = cast(list, edge_configs)
# reorganize edges mapping
source_edges_mapping: dict[str, list[dict]] = {}
target_edge_ids = set()
for edge_config in edge_configs:
source_node_id = edge_config.get('source')
if not source_node_id:
continue
if source_node_id not in source_edges_mapping:
source_edges_mapping[source_node_id] = []
source_edges_mapping[source_node_id].append(edge_config)
target_node_id = edge_config.get('target')
if target_node_id:
target_edge_ids.add(target_node_id)
# node configs
node_configs = graph_config.get('nodes')
if not node_configs:
return None
node_configs = cast(list, node_configs)
# fetch nodes that have no predecessor node
root_node_configs = []
nodes_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get('id')
if not node_id:
continue
if node_id not in target_edge_ids:
root_node_configs.append(node_config)
nodes_mapping[node_id] = node_config
# fetch root node
if root_node_id:
root_node_config = next((node_config for node_config in root_node_configs
if node_config.get('id') == root_node_id), None)
else:
# if no root node id, use the START type node as root node
root_node_config = next((node_config for node_config in root_node_configs
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
if not root_node_config:
return None
# init graph
graph = Graph.init(
root_node_config=root_node_config
)
# add edge from root node
self._recursively_add_edges(
graph=graph,
source_node_config=root_node_config,
edges_mapping=source_edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
return graph
def _recursively_add_edges(self, graph: Graph,
source_node_config: dict,
edges_mapping: dict,
nodes_mapping: dict,
root_node_configs: list[dict]) -> None:
"""
Add edges
:param source_node_config: source node config
:param edges_mapping: edges mapping
:param nodes_mapping: nodes mapping
:param root_node_configs: root node configs
"""
source_node_id = source_node_config.get('id')
if not source_node_id:
return
for edge_config in edges_mapping.get(source_node_id, []):
target_node_id = edge_config.get('target')
if not target_node_id:
continue
target_node_config = nodes_mapping.get(target_node_id)
if not target_node_config:
continue
sub_graph: Optional[Graph] = None
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
if target_node_type and target_node_type in [IterationNode.node_type, NodeType.LOOP]:
# find iteration/loop sub nodes that have no predecessor node
for root_node_config in root_node_configs:
if root_node_config.get('parentId') == target_node_id:
# create sub graph
sub_graph = Graph.init(
root_node_config=root_node_config
)
self._recursively_add_edges(
graph=sub_graph,
source_node_config=root_node_config,
edges_mapping=edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
break
# add edge
graph.add_edge(
edge_config=edge_config,
source_node_config=source_node_config,
target_node_config=target_node_config,
target_node_sub_graph=sub_graph,
)
# recursively add edges
self._recursively_add_edges(
graph=graph,
source_node_config=target_node_config,
edges_mapping=edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
def _run_workflow(self, graph_config: dict,
workflow_runtime_state: WorkflowRuntimeState,
callbacks: list[BaseWorkflowCallback],
@ -157,10 +310,15 @@ class WorkflowEngineManager:
"""
try:
# init graph
graph = Graph(
graph = self._init_graph(
graph_config=graph_config
)
if not graph:
raise WorkflowRunFailedError(
error='Start node not found in workflow graph.'
)
predecessor_node: Optional[BaseNode] = None
current_iteration_node: Optional[BaseIterationNode] = None
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
@ -231,11 +389,11 @@ class WorkflowEngineManager:
# max steps reached
if workflow_run_state.workflow_node_steps > max_execution_steps:
raise ValueError('Max steps {} reached.'.format(max_execution_steps))
raise WorkflowRunFailedError('Max steps {} reached.'.format(max_execution_steps))
# or max execution time reached
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time):
raise ValueError('Max execution time {}s reached.'.format(max_execution_time))
raise WorkflowRunFailedError('Max execution time {}s reached.'.format(max_execution_time))
if len(next_nodes) == 1:
next_node = next_nodes[0]
@ -256,63 +414,59 @@ class WorkflowEngineManager:
else:
result_dict = {}
# new thread
worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={
'flask_app': current_app._get_current_object(),
'graph': graph,
'workflow_run_state': workflow_run_state,
'predecessor_node': predecessor_node,
'next_nodes': next_nodes,
'callbacks': callbacks,
'result': result_dict
})
worker_thread.start()
worker_thread.join()
# # new thread
# worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={
# 'flask_app': current_app._get_current_object(),
# 'graph': graph,
# 'workflow_run_state': workflow_run_state,
# 'predecessor_node': predecessor_node,
# 'next_nodes': next_nodes,
# 'callbacks': callbacks,
# 'result': result_dict
# })
#
# worker_thread.start()
# worker_thread.join()
if not workflow_run_state.workflow_node_runs:
self._workflow_run_failed(
error='Start node not found in workflow graph.',
callbacks=callbacks
raise WorkflowRunFailedError(
error='Start node not found in workflow graph.'
)
return
except GenerateTaskStoppedException as e:
return
except Exception as e:
self._workflow_run_failed(
error=str(e),
callbacks=callbacks
raise WorkflowRunFailedError(
error=str(e)
)
return
def _async_run_nodes(self, flask_app: Flask,
graph: dict,
workflow_run_state: WorkflowRunState,
predecessor_node: Optional[BaseNode],
next_nodes: list[BaseNode],
callbacks: list[BaseWorkflowCallback],
result: dict):
with flask_app.app_context():
try:
for next_node in next_nodes:
# TODO run sub workflows
# run node
is_continue = self._run_node(
graph=graph,
workflow_run_state=workflow_run_state,
predecessor_node=predecessor_node,
current_node=next_node,
callbacks=callbacks
)
if not is_continue:
break
predecessor_node = next_node
except Exception as e:
logger.exception("Unknown Error when generating")
finally:
db.session.remove()
# def _async_run_nodes(self, flask_app: Flask,
# graph: dict,
# workflow_run_state: WorkflowRunState,
# predecessor_node: Optional[BaseNode],
# next_nodes: list[BaseNode],
# callbacks: list[BaseWorkflowCallback],
# result: dict):
# with flask_app.app_context():
# try:
# for next_node in next_nodes:
# # TODO run sub workflows
# # run node
# is_continue = self._run_node(
# graph=graph,
# workflow_run_state=workflow_run_state,
# predecessor_node=predecessor_node,
# current_node=next_node,
# callbacks=callbacks
# )
#
# if not is_continue:
# break
#
# predecessor_node = next_node
# except Exception as e:
# logger.exception("Unknown Error when generating")
# finally:
# db.session.remove()
def _run_node(self, graph: dict,
workflow_run_state: WorkflowRunState,
@ -584,14 +738,25 @@ class WorkflowEngineManager:
workflow_call_depth=0
)
# run workflow
self._run_workflow(
graph=workflow.graph,
workflow_run_state=workflow_run_state,
callbacks=callbacks,
start_node=node_id,
end_node=end_node_id
)
try:
# run workflow
self._run_workflow(
graph_config=workflow.graph,
workflow_runtime_state=workflow_runtime_state,
callbacks=callbacks,
start_node=node_id,
end_node=end_node_id
)
except WorkflowRunFailedError as e:
self._workflow_run_failed(
error=e.error,
callbacks=callbacks
)
except Exception as e:
self._workflow_run_failed(
error=str(e),
callbacks=callbacks
)
# workflow run success
self._workflow_run_success(
@ -1072,3 +1237,8 @@ class WorkflowEngineManager:
variable_key_list=variable_key_list,
value=value
)
class WorkflowRunFailedError(Exception):
def __init__(self, error: str):
self.error = error

View File

@ -0,0 +1,234 @@
from core.workflow.graph import Graph
from core.workflow.workflow_engine_manager import WorkflowEngineManager
def test__init_graph():
graph_config = {
"edges": [
{
"id": "llm-source-answer-target",
"source": "llm",
"target": "answer",
},
{
"id": "1717222650545-source-1719481290322-target",
"source": "1717222650545",
"target": "1719481290322",
},
{
"id": "1719481290322-1-llm-target",
"source": "1719481290322",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "1719481290322-2-1719481315734-target",
"source": "1719481290322",
"sourceHandle": "2",
"target": "1719481315734",
},
{
"id": "1719481315734-source-1719481326339-target",
"source": "1719481315734",
"target": "1719481326339",
}
],
"nodes": [
{
"data": {
"desc": "",
"title": "Start",
"type": "start",
"variables": [
{
"label": "name",
"max_length": 48,
"options": [],
"required": False,
"type": "text-input",
"variable": "name"
}
]
},
"id": "1717222650545",
"position": {
"x": -147.65487258270954,
"y": 263.5326708413438
},
},
{
"data": {
"context": {
"enabled": False,
"variable_selector": []
},
"desc": "",
"memory": {
"query_prompt_template": "{{#sys.query#}}",
"role_prefix": {
"assistant": "",
"user": ""
},
"window": {
"enabled": False,
"size": 10
}
},
"model": {
"completion_params": {
"temperature": 0
},
"mode": "chat",
"name": "anthropic.claude-3-sonnet-20240229-v1:0",
"provider": "bedrock"
},
"prompt_config": {
"jinja2_variables": [
{
"value_selector": [
"sys",
"query"
],
"variable": "query"
}
]
},
"prompt_template": [
{
"edition_type": "basic",
"id": "8b02d178-3aa0-4dbd-82bf-8b6a40658300",
"jinja2_text": "",
"role": "system",
"text": "yep"
}
],
"title": "LLM",
"type": "llm",
"variables": [],
"vision": {
"configs": {
"detail": "low"
},
"enabled": True
}
},
"id": "llm",
"position": {
"x": 654.0331237272932,
"y": 263.5326708413438
},
},
{
"data": {
"answer": "123{{#llm.text#}}",
"desc": "",
"title": "Answer",
"type": "answer",
"variables": []
},
"id": "answer",
"position": {
"x": 958.1129142362784,
"y": 263.5326708413438
},
},
{
"data": {
"classes": [
{
"id": "1",
"name": "happy"
},
{
"id": "2",
"name": "sad"
}
],
"desc": "",
"instructions": "",
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai"
},
"query_variable_selector": [
"1717222650545",
"sys.query"
],
"title": "Question Classifier",
"topics": [],
"type": "question-classifier"
},
"id": "1719481290322",
"position": {
"x": 165.25154615277052,
"y": 263.5326708413438
}
},
{
"data": {
"authorization": {
"config": None,
"type": "no-auth"
},
"body": {
"data": "",
"type": "none"
},
"desc": "",
"headers": "",
"method": "get",
"params": "",
"timeout": {
"max_connect_timeout": 0,
"max_read_timeout": 0,
"max_write_timeout": 0
},
"title": "HTTP Request",
"type": "http-request",
"url": "https://baidu.com",
"variables": []
},
"height": 88,
"id": "1719481315734",
"position": {
"x": 654.0331237272932,
"y": 474.1180064703089
}
},
{
"data": {
"answer": "{{#1719481315734.status_code#}}",
"desc": "",
"title": "Answer 2",
"type": "answer",
"variables": []
},
"height": 105,
"id": "1719481326339",
"position": {
"x": 958.1129142362784,
"y": 474.1180064703089
},
}
],
}
workflow_engine_manager = WorkflowEngineManager()
graph = workflow_engine_manager._init_graph(
graph_config=graph_config
)
assert graph.root_node.id == "1717222650545"
assert graph.root_node.source_edge_config is None
assert graph.root_node.target_edge_config is not None
assert graph.root_node.descendant_node_ids == ["1719481290322"]
assert graph.graph_nodes.get("1719481290322") is not None
assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2
assert graph.graph_nodes.get("llm").run_condition_callback is not None
assert graph.graph_nodes.get("1719481315734").run_condition_callback is not None