diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index c04770616c..cac1843529 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,6 +1,8 @@ from enum import Enum from typing import Any, Optional, Union +from pydantic import BaseModel, Field + from core.file.file_obj import FileVar from core.workflow.entities.node_entities import SystemVariable @@ -21,20 +23,23 @@ class ValueType(Enum): FILE = "file" -class VariablePool: +class VariablePool(BaseModel): - def __init__(self, system_variables: dict[SystemVariable, Any], - user_inputs: dict) -> None: - # system variables - # for example: - # { - # 'query': 'abc', - # 'files': [] - # } - self.variables_mapping = {} - self.user_inputs = user_inputs - self.system_variables = system_variables - for system_variable, value in system_variables.items(): + variables_mapping: dict[str, dict[int, VariableValue]] = Field( + description='Variables mapping', + default={}, + ) + + user_inputs: dict = Field( + description='User inputs', + ) + + system_variables: dict[SystemVariable, Any] = Field( + description='System variables', + ) + + def __post_init__(self): + for system_variable, value in self.system_variables.items(): self.append_variable('sys', [system_variable.value], value) def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None: diff --git a/api/core/workflow/graph.py b/api/core/workflow/graph.py index df49199d15..be981c2ea8 100644 --- a/api/core/workflow/graph.py +++ b/api/core/workflow/graph.py @@ -1,21 +1,40 @@ from collections.abc import Callable -from typing import Optional +from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field + +from core.workflow.utils.condition.entities import Condition + + +class RunCondition(BaseModel): + type: Literal["branch_identify", "condition"] + """condition type""" + + branch_identify: Optional[str] = None + """branch identify, required when type is branch_identify""" + + conditions: Optional[list[Condition]] = None + """conditions to run the node, required when type is condition""" class GraphNode(BaseModel): id: str """node id""" + parent_id: Optional[str] = None + """parent node id, e.g. iteration/loop""" + predecessor_node_id: Optional[str] = None """predecessor node id""" - children_node_ids: list[str] = [] - """children node ids""" + descendant_node_ids: list[str] = [] + """descendant node ids""" - run_condition_callback: Optional[Callable] = None - """condition function check if the node can be executed""" + run_condition: Optional[RunCondition] = None + """condition to run the node""" + + run_condition_callback: Optional[Callable] = Field(None, exclude=True) + """condition function check if the node can be executed, translated from run_conditions, not serialized""" node_config: dict """original node config""" @@ -23,72 +42,96 @@ class GraphNode(BaseModel): source_edge_config: Optional[dict] = None """original source edge config""" - target_edge_config: Optional[dict] = None - """original target edge config""" + sub_graph: Optional["Graph"] = None + """sub graph of the node, e.g. iteration/loop sub graph""" def add_child(self, node_id: str) -> None: - self.children_node_ids.append(node_id) + self.descendant_node_ids.append(node_id) class Graph(BaseModel): - graph_config: dict - """graph config from workflow""" - graph_nodes: dict[str, GraphNode] = {} """graph nodes""" - root_node: Optional[GraphNode] = None + root_node: GraphNode """root node of the graph""" + @classmethod + def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph": + """ + Init graph + + :param root_node_config: root node config + :param run_condition: run condition when root node parent is iteration/loop + :return: graph + """ + root_node = GraphNode( + id=root_node_config.get('id'), + parent_id=root_node_config.get('parentId'), + node_config=root_node_config, + run_condition=run_condition + ) + + graph = cls(root_node=root_node) + + # TODO parse run_condition to run_condition_callback + + graph.add_graph_node(graph.root_node) + return graph + def add_edge(self, edge_config: dict, source_node_config: dict, target_node_config: dict, - run_condition_callback: Optional[Callable] = None) -> None: + target_node_sub_graph: Optional["Graph"] = None) -> None: """ Add edge to the graph :param edge_config: edge config :param source_node_config: source node config :param target_node_config: target node config - :param run_condition_callback: condition callback + :param target_node_sub_graph: sub graph """ source_node_id = source_node_config.get('id') if not source_node_id: return + if source_node_id not in self.graph_nodes: + return + target_node_id = target_node_config.get('id') if not target_node_id: return - if source_node_id not in self.graph_nodes: - source_graph_node = GraphNode( - id=source_node_id, - node_config=source_node_config, - children_node_ids=[target_node_id], - target_edge_config=edge_config, - ) + source_node = self.graph_nodes[source_node_id] + source_node.add_child(target_node_id) + + # if run_conditions: + # run_condition_callback = lambda: all() - self.add_graph_node(source_graph_node) - else: - source_node = self.graph_nodes[source_node_id] - source_node.add_child(target_node_id) - source_node.target_edge_config = edge_config if target_node_id not in self.graph_nodes: + run_condition = None # todo + run_condition_callback = None # todo + target_graph_node = GraphNode( id=target_node_id, + parent_id=source_node_config.get('parentId'), predecessor_node_id=source_node_id, node_config=target_node_config, + run_condition=run_condition, run_condition_callback=run_condition_callback, source_edge_config=edge_config, + sub_graph=target_node_sub_graph ) self.add_graph_node(target_graph_node) else: target_node = self.graph_nodes[target_node_id] target_node.predecessor_node_id = source_node_id + target_node.run_conditions = run_conditions target_node.run_condition_callback = run_condition_callback target_node.source_edge_config = edge_config + target_node.sub_graph = target_node_sub_graph def add_graph_node(self, graph_node: GraphNode) -> None: """ @@ -123,13 +166,13 @@ class Graph(BaseModel): return None graph_node = self.graph_nodes[node_id] - if not graph_node.children_node_ids: + if not graph_node.descendant_node_ids: return None - descendants_graph = Graph(graph_config=self.graph_config) + descendants_graph = Graph() descendants_graph.add_graph_node(graph_node) - for child_node_id in graph_node.children_node_ids: + for child_node_id in graph_node.descendant_node_ids: self._add_descendants_graph_nodes(descendants_graph, child_node_id) return descendants_graph @@ -147,5 +190,5 @@ class Graph(BaseModel): graph_node = self.graph_nodes[node_id] descendants_graph.add_graph_node(graph_node) - for child_node_id in graph_node.children_node_ids: + for child_node_id in graph_node.descendant_node_ids: self._add_descendants_graph_nodes(descendants_graph, child_node_id) diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 68d51c93be..8af087a6c0 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -1,26 +1,12 @@ -from typing import Literal, Optional - -from pydantic import BaseModel +from typing import Literal from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.utils.condition.entities import Condition class IfElseNodeData(BaseNodeData): """ Answer Node Data. """ - class Condition(BaseModel): - """ - Condition entity - """ - variable_selector: list[str] - comparison_operator: Literal[ - # for string or array - "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", - # for number - "=", "≠", ">", "<", "≥", "≤", "null", "not null" - ] - value: Optional[str] = None - logical_operator: Literal["and", "or"] = "and" conditions: list[Condition] diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 44a4091a2e..9d53214972 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,10 +1,11 @@ -from typing import Optional, cast +from typing import cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.utils.condition.processor import ConditionAssertionError, ConditionProcessor from models.workflow import WorkflowNodeExecutionStatus @@ -19,90 +20,42 @@ class IfElseNode(BaseNode): :return: """ node_data = self.node_data - node_data = cast(self._node_data_cls, node_data) + node_data = cast(IfElseNodeData, node_data) - node_inputs = { + node_inputs: dict[str, list] = { "conditions": [] } - process_datas = { + process_datas: dict[str, list] = { "condition_results": [] } try: - logical_operator = node_data.logical_operator - input_conditions = [] - for condition in node_data.conditions: - actual_value = variable_pool.get_variable_value( - variable_selector=condition.variable_selector - ) + processor = ConditionProcessor() + compare_result, sub_condition_compare_results = processor.process( + variable_pool=variable_pool, + logical_operator=node_data.logical_operator, + conditions=node_data.conditions, + ) - expected_value = condition.value + node_inputs["conditions"] = [{ + "actual_value": result['actual_value'], + "expected_value": result['expected_value'], + "comparison_operator": result['comparison_operator'], + } for result in sub_condition_compare_results] - input_conditions.append({ - "actual_value": actual_value, - "expected_value": expected_value, - "comparison_operator": condition.comparison_operator - }) - - node_inputs["conditions"] = input_conditions - - for input_condition in input_conditions: - actual_value = input_condition["actual_value"] - expected_value = input_condition["expected_value"] - comparison_operator = input_condition["comparison_operator"] - - if comparison_operator == "contains": - compare_result = self._assert_contains(actual_value, expected_value) - elif comparison_operator == "not contains": - compare_result = self._assert_not_contains(actual_value, expected_value) - elif comparison_operator == "start with": - compare_result = self._assert_start_with(actual_value, expected_value) - elif comparison_operator == "end with": - compare_result = self._assert_end_with(actual_value, expected_value) - elif comparison_operator == "is": - compare_result = self._assert_is(actual_value, expected_value) - elif comparison_operator == "is not": - compare_result = self._assert_is_not(actual_value, expected_value) - elif comparison_operator == "empty": - compare_result = self._assert_empty(actual_value) - elif comparison_operator == "not empty": - compare_result = self._assert_not_empty(actual_value) - elif comparison_operator == "=": - compare_result = self._assert_equal(actual_value, expected_value) - elif comparison_operator == "≠": - compare_result = self._assert_not_equal(actual_value, expected_value) - elif comparison_operator == ">": - compare_result = self._assert_greater_than(actual_value, expected_value) - elif comparison_operator == "<": - compare_result = self._assert_less_than(actual_value, expected_value) - elif comparison_operator == "≥": - compare_result = self._assert_greater_than_or_equal(actual_value, expected_value) - elif comparison_operator == "≤": - compare_result = self._assert_less_than_or_equal(actual_value, expected_value) - elif comparison_operator == "null": - compare_result = self._assert_null(actual_value) - elif comparison_operator == "not null": - compare_result = self._assert_not_null(actual_value) - else: - continue - - process_datas["condition_results"].append({ - **input_condition, - "result": compare_result - }) - except Exception as e: + process_datas["condition_results"] = sub_condition_compare_results + except ConditionAssertionError as e: + node_inputs["conditions"] = e.conditions + process_datas["condition_results"] = e.sub_condition_compare_results return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e) ) - - if logical_operator == "and": - compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]] - else: - compare_result = True in [condition["result"] for condition in process_datas["condition_results"]] + except Exception as e: + raise e return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -114,280 +67,6 @@ class IfElseNode(BaseNode): } ) - def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') - - if expected_value not in actual_value: - return False - return True - - def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert not contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return True - - if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') - - if expected_value in actual_value: - return False - return True - - def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert start with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.startswith(expected_value): - return False - return True - - def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert end with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.endswith(expected_value): - return False - return True - - def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value != expected_value: - return False - return True - - def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is not - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value == expected_value: - return False - return True - - def _assert_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert empty - :param actual_value: actual value - :return: - """ - if not actual_value: - return True - return False - - def _assert_not_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert not empty - :param actual_value: actual value - :return: - """ - if actual_value: - return True - return False - - def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value != expected_value: - return False - return True - - def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert not equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value == expected_value: - return False - return True - - def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value <= expected_value: - return False - return True - - def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value >= expected_value: - return False - return True - - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value < expected_value: - return False - return True - - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value > expected_value: - return False - return True - - def _assert_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert null - :param actual_value: actual value - :return: - """ - if actual_value is None: - return True - return False - - def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert not null - :param actual_value: actual value - :return: - """ - if actual_value is not None: - return True - return False - @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ diff --git a/api/core/workflow/utils/condition/__init__.py b/api/core/workflow/utils/condition/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py new file mode 100644 index 0000000000..e195730a31 --- /dev/null +++ b/api/core/workflow/utils/condition/entities.py @@ -0,0 +1,17 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + + +class Condition(BaseModel): + """ + Condition entity + """ + variable_selector: list[str] + comparison_operator: Literal[ + # for string or array + "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", + # for number + "=", "≠", ">", "<", "≥", "≤", "null", "not null" + ] + value: Optional[str] = None diff --git a/api/core/workflow/utils/condition/funcs.py b/api/core/workflow/utils/condition/funcs.py new file mode 100644 index 0000000000..fa2f3ca9f9 --- /dev/null +++ b/api/core/workflow/utils/condition/funcs.py @@ -0,0 +1,22 @@ +from typing import Optional + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState +from core.workflow.graph import GraphNode + + +def source_handle_condition_func(workflow_runtime_state: WorkflowRuntimeState, + graph_node: GraphNode, + # TODO cycle_state optional + predecessor_node_run_result: Optional[NodeRunResult] = None) -> bool: + if not graph_node.source_edge_config: + return False + + if not graph_node.source_edge_config.get('sourceHandle'): + return True + + source_handle = predecessor_node_run_result.edge_source_handle \ + if predecessor_node_run_result else None + + return (source_handle is not None + and graph_node.source_edge_config.get('sourceHandle') == source_handle) diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py new file mode 100644 index 0000000000..405b01b9fa --- /dev/null +++ b/api/core/workflow/utils/condition/processor.py @@ -0,0 +1,369 @@ +from typing import Literal, Optional + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils.condition.entities import Condition + + +class ConditionProcessor: + def process(self, variable_pool: VariablePool, + logical_operator: Literal["and", "or"], + conditions: list[Condition]) -> tuple[bool, list[dict]]: + """ + Process conditions + + :param variable_pool: variable pool + :param logical_operator: logical operator + :param conditions: conditions + """ + input_conditions = [] + sub_condition_compare_results = [] + + try: + for condition in conditions: + actual_value = variable_pool.get_variable_value( + variable_selector=condition.variable_selector + ) + + expected_value = condition.value + + input_conditions.append({ + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": condition.comparison_operator + }) + + for input_condition in input_conditions: + actual_value = input_condition["actual_value"] + expected_value = input_condition["expected_value"] + comparison_operator = input_condition["comparison_operator"] + + if comparison_operator == "contains": + compare_result = self._assert_contains(actual_value, expected_value) + elif comparison_operator == "not contains": + compare_result = self._assert_not_contains(actual_value, expected_value) + elif comparison_operator == "start with": + compare_result = self._assert_start_with(actual_value, expected_value) + elif comparison_operator == "end with": + compare_result = self._assert_end_with(actual_value, expected_value) + elif comparison_operator == "is": + compare_result = self._assert_is(actual_value, expected_value) + elif comparison_operator == "is not": + compare_result = self._assert_is_not(actual_value, expected_value) + elif comparison_operator == "empty": + compare_result = self._assert_empty(actual_value) + elif comparison_operator == "not empty": + compare_result = self._assert_not_empty(actual_value) + elif comparison_operator == "=": + compare_result = self._assert_equal(actual_value, expected_value) + elif comparison_operator == "≠": + compare_result = self._assert_not_equal(actual_value, expected_value) + elif comparison_operator == ">": + compare_result = self._assert_greater_than(actual_value, expected_value) + elif comparison_operator == "<": + compare_result = self._assert_less_than(actual_value, expected_value) + elif comparison_operator == "≥": + compare_result = self._assert_greater_than_or_equal(actual_value, expected_value) + elif comparison_operator == "≤": + compare_result = self._assert_less_than_or_equal(actual_value, expected_value) + elif comparison_operator == "null": + compare_result = self._assert_null(actual_value) + elif comparison_operator == "not null": + compare_result = self._assert_not_null(actual_value) + else: + continue + + sub_condition_compare_results.append({ + **input_condition, + "result": compare_result + }) + except Exception as e: + raise ConditionAssertionError(str(e), input_conditions, sub_condition_compare_results) + + if logical_operator == "and": + compare_result = False not in [condition["result"] for condition in sub_condition_compare_results] + else: + compare_result = True in [condition["result"] for condition in sub_condition_compare_results] + + return compare_result, sub_condition_compare_results + + def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value not in actual_value: + return False + return True + + def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: + """ + Assert not contains + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return True + + if not isinstance(actual_value, str | list): + raise ValueError('Invalid actual value type: string or array') + + if expected_value in actual_value: + return False + return True + + def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert start with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.startswith(expected_value): + return False + return True + + def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert end with + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if not actual_value: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if not actual_value.endswith(expected_value): + return False + return True + + def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value != expected_value: + return False + return True + + def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: + """ + Assert is not + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, str): + raise ValueError('Invalid actual value type: string') + + if actual_value == expected_value: + return False + return True + + def _assert_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert empty + :param actual_value: actual value + :return: + """ + if not actual_value: + return True + return False + + def _assert_not_empty(self, actual_value: Optional[str]) -> bool: + """ + Assert not empty + :param actual_value: actual value + :return: + """ + if actual_value: + return True + return False + + def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value != expected_value: + return False + return True + + def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert not equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value == expected_value: + return False + return True + + def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert greater than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value <= expected_value: + return False + return True + + def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert less than + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value >= expected_value: + return False + return True + + def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert greater than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value < expected_value: + return False + return True + + def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + """ + Assert less than or equal + :param actual_value: actual value + :param expected_value: expected value + :return: + """ + if actual_value is None: + return False + + if not isinstance(actual_value, int | float): + raise ValueError('Invalid actual value type: number') + + if isinstance(actual_value, int): + expected_value = int(expected_value) + else: + expected_value = float(expected_value) + + if actual_value > expected_value: + return False + return True + + def _assert_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert null + :param actual_value: actual value + :return: + """ + if actual_value is None: + return True + return False + + def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: + """ + Assert not null + :param actual_value: actual value + :return: + """ + if actual_value is not None: + return True + return False + + +class ConditionAssertionError(Exception): + def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None: + self.message = message + self.conditions = conditions + self.sub_condition_compare_results = sub_condition_compare_results + super().__init__(self.message) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 544438a5bf..66c081860d 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,5 +1,4 @@ import logging -import threading import time from typing import Any, Optional, cast @@ -129,18 +128,172 @@ class WorkflowEngineManager: callbacks=callbacks ) - # run workflow - self._run_workflow( - graph_config=graph_config, - workflow_runtime_state=workflow_runtime_state, - callbacks=callbacks, - ) + try: + # run workflow + self._run_workflow( + graph_config=graph_config, + workflow_runtime_state=workflow_runtime_state, + callbacks=callbacks, + ) + except WorkflowRunFailedError as e: + self._workflow_run_failed( + error=e.error, + callbacks=callbacks + ) + except Exception as e: + self._workflow_run_failed( + error=str(e), + callbacks=callbacks + ) # workflow run success self._workflow_run_success( callbacks=callbacks ) + def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]: + """ + Initialize graph + + :param graph_config: graph config + :param root_node_id: root node id if needed + :return: graph + """ + # edge configs + edge_configs = graph_config.get('edges') + if not edge_configs: + return None + + edge_configs = cast(list, edge_configs) + + # reorganize edges mapping + source_edges_mapping: dict[str, list[dict]] = {} + target_edge_ids = set() + for edge_config in edge_configs: + source_node_id = edge_config.get('source') + if not source_node_id: + continue + + if source_node_id not in source_edges_mapping: + source_edges_mapping[source_node_id] = [] + + source_edges_mapping[source_node_id].append(edge_config) + + target_node_id = edge_config.get('target') + if target_node_id: + target_edge_ids.add(target_node_id) + + # node configs + node_configs = graph_config.get('nodes') + if not node_configs: + return None + + node_configs = cast(list, node_configs) + + # fetch nodes that have no predecessor node + root_node_configs = [] + nodes_mapping: dict[str, dict] = {} + for node_config in node_configs: + node_id = node_config.get('id') + if not node_id: + continue + + if node_id not in target_edge_ids: + root_node_configs.append(node_config) + + nodes_mapping[node_id] = node_config + + # fetch root node + if root_node_id: + root_node_config = next((node_config for node_config in root_node_configs + if node_config.get('id') == root_node_id), None) + else: + # if no root node id, use the START type node as root node + root_node_config = next((node_config for node_config in root_node_configs + if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) + + if not root_node_config: + return None + + # init graph + graph = Graph.init( + root_node_config=root_node_config + ) + + # add edge from root node + self._recursively_add_edges( + graph=graph, + source_node_config=root_node_config, + edges_mapping=source_edges_mapping, + nodes_mapping=nodes_mapping, + root_node_configs=root_node_configs + ) + + return graph + + def _recursively_add_edges(self, graph: Graph, + source_node_config: dict, + edges_mapping: dict, + nodes_mapping: dict, + root_node_configs: list[dict]) -> None: + """ + Add edges + + :param source_node_config: source node config + :param edges_mapping: edges mapping + :param nodes_mapping: nodes mapping + :param root_node_configs: root node configs + """ + source_node_id = source_node_config.get('id') + if not source_node_id: + return + + for edge_config in edges_mapping.get(source_node_id, []): + target_node_id = edge_config.get('target') + if not target_node_id: + continue + + target_node_config = nodes_mapping.get(target_node_id) + if not target_node_config: + continue + + sub_graph: Optional[Graph] = None + target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type')) + if target_node_type and target_node_type in [IterationNode.node_type, NodeType.LOOP]: + # find iteration/loop sub nodes that have no predecessor node + for root_node_config in root_node_configs: + if root_node_config.get('parentId') == target_node_id: + # create sub graph + sub_graph = Graph.init( + root_node_config=root_node_config + ) + + self._recursively_add_edges( + graph=sub_graph, + source_node_config=root_node_config, + edges_mapping=edges_mapping, + nodes_mapping=nodes_mapping, + root_node_configs=root_node_configs + ) + break + + # add edge + graph.add_edge( + edge_config=edge_config, + source_node_config=source_node_config, + target_node_config=target_node_config, + target_node_sub_graph=sub_graph, + ) + + # recursively add edges + self._recursively_add_edges( + graph=graph, + source_node_config=target_node_config, + edges_mapping=edges_mapping, + nodes_mapping=nodes_mapping, + root_node_configs=root_node_configs + ) + def _run_workflow(self, graph_config: dict, workflow_runtime_state: WorkflowRuntimeState, callbacks: list[BaseWorkflowCallback], @@ -157,10 +310,15 @@ class WorkflowEngineManager: """ try: # init graph - graph = Graph( + graph = self._init_graph( graph_config=graph_config ) + if not graph: + raise WorkflowRunFailedError( + error='Start node not found in workflow graph.' + ) + predecessor_node: Optional[BaseNode] = None current_iteration_node: Optional[BaseIterationNode] = None max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") @@ -231,11 +389,11 @@ class WorkflowEngineManager: # max steps reached if workflow_run_state.workflow_node_steps > max_execution_steps: - raise ValueError('Max steps {} reached.'.format(max_execution_steps)) + raise WorkflowRunFailedError('Max steps {} reached.'.format(max_execution_steps)) # or max execution time reached if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): - raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) + raise WorkflowRunFailedError('Max execution time {}s reached.'.format(max_execution_time)) if len(next_nodes) == 1: next_node = next_nodes[0] @@ -256,63 +414,59 @@ class WorkflowEngineManager: else: result_dict = {} - # new thread - worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={ - 'flask_app': current_app._get_current_object(), - 'graph': graph, - 'workflow_run_state': workflow_run_state, - 'predecessor_node': predecessor_node, - 'next_nodes': next_nodes, - 'callbacks': callbacks, - 'result': result_dict - }) - - worker_thread.start() - worker_thread.join() + # # new thread + # worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={ + # 'flask_app': current_app._get_current_object(), + # 'graph': graph, + # 'workflow_run_state': workflow_run_state, + # 'predecessor_node': predecessor_node, + # 'next_nodes': next_nodes, + # 'callbacks': callbacks, + # 'result': result_dict + # }) + # + # worker_thread.start() + # worker_thread.join() if not workflow_run_state.workflow_node_runs: - self._workflow_run_failed( - error='Start node not found in workflow graph.', - callbacks=callbacks + raise WorkflowRunFailedError( + error='Start node not found in workflow graph.' ) - return except GenerateTaskStoppedException as e: return except Exception as e: - self._workflow_run_failed( - error=str(e), - callbacks=callbacks + raise WorkflowRunFailedError( + error=str(e) ) - return - def _async_run_nodes(self, flask_app: Flask, - graph: dict, - workflow_run_state: WorkflowRunState, - predecessor_node: Optional[BaseNode], - next_nodes: list[BaseNode], - callbacks: list[BaseWorkflowCallback], - result: dict): - with flask_app.app_context(): - try: - for next_node in next_nodes: - # TODO run sub workflows - # run node - is_continue = self._run_node( - graph=graph, - workflow_run_state=workflow_run_state, - predecessor_node=predecessor_node, - current_node=next_node, - callbacks=callbacks - ) - - if not is_continue: - break - - predecessor_node = next_node - except Exception as e: - logger.exception("Unknown Error when generating") - finally: - db.session.remove() + # def _async_run_nodes(self, flask_app: Flask, + # graph: dict, + # workflow_run_state: WorkflowRunState, + # predecessor_node: Optional[BaseNode], + # next_nodes: list[BaseNode], + # callbacks: list[BaseWorkflowCallback], + # result: dict): + # with flask_app.app_context(): + # try: + # for next_node in next_nodes: + # # TODO run sub workflows + # # run node + # is_continue = self._run_node( + # graph=graph, + # workflow_run_state=workflow_run_state, + # predecessor_node=predecessor_node, + # current_node=next_node, + # callbacks=callbacks + # ) + # + # if not is_continue: + # break + # + # predecessor_node = next_node + # except Exception as e: + # logger.exception("Unknown Error when generating") + # finally: + # db.session.remove() def _run_node(self, graph: dict, workflow_run_state: WorkflowRunState, @@ -584,14 +738,25 @@ class WorkflowEngineManager: workflow_call_depth=0 ) - # run workflow - self._run_workflow( - graph=workflow.graph, - workflow_run_state=workflow_run_state, - callbacks=callbacks, - start_node=node_id, - end_node=end_node_id - ) + try: + # run workflow + self._run_workflow( + graph_config=workflow.graph, + workflow_runtime_state=workflow_runtime_state, + callbacks=callbacks, + start_node=node_id, + end_node=end_node_id + ) + except WorkflowRunFailedError as e: + self._workflow_run_failed( + error=e.error, + callbacks=callbacks + ) + except Exception as e: + self._workflow_run_failed( + error=str(e), + callbacks=callbacks + ) # workflow run success self._workflow_run_success( @@ -1072,3 +1237,8 @@ class WorkflowEngineManager: variable_key_list=variable_key_list, value=value ) + + +class WorkflowRunFailedError(Exception): + def __init__(self, error: str): + self.error = error diff --git a/api/tests/unit_tests/core/workflow/test_workflow_engine_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_engine_manager.py new file mode 100644 index 0000000000..0d3ba65843 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_workflow_engine_manager.py @@ -0,0 +1,234 @@ +from core.workflow.graph import Graph +from core.workflow.workflow_engine_manager import WorkflowEngineManager + + +def test__init_graph(): + graph_config = { + "edges": [ + { + "id": "llm-source-answer-target", + "source": "llm", + "target": "answer", + }, + { + "id": "1717222650545-source-1719481290322-target", + "source": "1717222650545", + "target": "1719481290322", + }, + { + "id": "1719481290322-1-llm-target", + "source": "1719481290322", + "sourceHandle": "1", + "target": "llm", + }, + { + "id": "1719481290322-2-1719481315734-target", + "source": "1719481290322", + "sourceHandle": "2", + "target": "1719481315734", + }, + { + "id": "1719481315734-source-1719481326339-target", + "source": "1719481315734", + "target": "1719481326339", + } + ], + "nodes": [ + { + "data": { + "desc": "", + "title": "Start", + "type": "start", + "variables": [ + { + "label": "name", + "max_length": 48, + "options": [], + "required": False, + "type": "text-input", + "variable": "name" + } + ] + }, + "id": "1717222650545", + "position": { + "x": -147.65487258270954, + "y": 263.5326708413438 + }, + }, + { + "data": { + "context": { + "enabled": False, + "variable_selector": [] + }, + "desc": "", + "memory": { + "query_prompt_template": "{{#sys.query#}}", + "role_prefix": { + "assistant": "", + "user": "" + }, + "window": { + "enabled": False, + "size": 10 + } + }, + "model": { + "completion_params": { + "temperature": 0 + }, + "mode": "chat", + "name": "anthropic.claude-3-sonnet-20240229-v1:0", + "provider": "bedrock" + }, + "prompt_config": { + "jinja2_variables": [ + { + "value_selector": [ + "sys", + "query" + ], + "variable": "query" + } + ] + }, + "prompt_template": [ + { + "edition_type": "basic", + "id": "8b02d178-3aa0-4dbd-82bf-8b6a40658300", + "jinja2_text": "", + "role": "system", + "text": "yep" + } + ], + "title": "LLM", + "type": "llm", + "variables": [], + "vision": { + "configs": { + "detail": "low" + }, + "enabled": True + } + }, + "id": "llm", + "position": { + "x": 654.0331237272932, + "y": 263.5326708413438 + }, + }, + { + "data": { + "answer": "123{{#llm.text#}}", + "desc": "", + "title": "Answer", + "type": "answer", + "variables": [] + }, + "id": "answer", + "position": { + "x": 958.1129142362784, + "y": 263.5326708413438 + }, + }, + { + "data": { + "classes": [ + { + "id": "1", + "name": "happy" + }, + { + "id": "2", + "name": "sad" + } + ], + "desc": "", + "instructions": "", + "model": { + "completion_params": { + "temperature": 0.7 + }, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai" + }, + "query_variable_selector": [ + "1717222650545", + "sys.query" + ], + "title": "Question Classifier", + "topics": [], + "type": "question-classifier" + }, + "id": "1719481290322", + "position": { + "x": 165.25154615277052, + "y": 263.5326708413438 + } + }, + { + "data": { + "authorization": { + "config": None, + "type": "no-auth" + }, + "body": { + "data": "", + "type": "none" + }, + "desc": "", + "headers": "", + "method": "get", + "params": "", + "timeout": { + "max_connect_timeout": 0, + "max_read_timeout": 0, + "max_write_timeout": 0 + }, + "title": "HTTP Request", + "type": "http-request", + "url": "https://baidu.com", + "variables": [] + }, + "height": 88, + "id": "1719481315734", + "position": { + "x": 654.0331237272932, + "y": 474.1180064703089 + } + }, + { + "data": { + "answer": "{{#1719481315734.status_code#}}", + "desc": "", + "title": "Answer 2", + "type": "answer", + "variables": [] + }, + "height": 105, + "id": "1719481326339", + "position": { + "x": 958.1129142362784, + "y": 474.1180064703089 + }, + } + ], + } + + workflow_engine_manager = WorkflowEngineManager() + graph = workflow_engine_manager._init_graph( + graph_config=graph_config + ) + + assert graph.root_node.id == "1717222650545" + assert graph.root_node.source_edge_config is None + assert graph.root_node.target_edge_config is not None + assert graph.root_node.descendant_node_ids == ["1719481290322"] + + assert graph.graph_nodes.get("1719481290322") is not None + assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2 + + assert graph.graph_nodes.get("llm").run_condition_callback is not None + assert graph.graph_nodes.get("1719481315734").run_condition_callback is not None