diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index cac1843529..b12ef1c64a 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from core.file.file_obj import FileVar from core.workflow.entities.node_entities import SystemVariable @@ -38,9 +38,21 @@ class VariablePool(BaseModel): description='System variables', ) - def __post_init__(self): - for system_variable, value in self.system_variables.items(): - self.append_variable('sys', [system_variable.value], value) + @model_validator(mode='before') + def append_system_variables(cls, v: dict) -> dict: + """ + Append system variables + :param v: params + :return: + """ + v['variables_mapping'] = { + 'sys': {} + } + system_variables = v['system_variables'] + for system_variable, value in system_variables.items(): + variable_key_list_hash = hash((system_variable.value,)) + v['variables_mapping']['sys'][variable_key_list_hash] = value + return v def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None: """ diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py index bdaa3d1529..80f6b6a8a7 100644 --- a/api/core/workflow/graph_engine/condition_handlers/base_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -4,6 +4,7 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState class RunConditionHandler(ABC): @@ -18,13 +19,13 @@ class RunConditionHandler(ABC): @abstractmethod def check(self, graph_runtime_state: GraphRuntimeState, - source_node_id: str, + previous_route_node_state: RouteNodeState, target_node_id: str) -> bool: """ Check if the condition can be executed :param graph_runtime_state: graph runtime state - :param source_node_id: source node id + :param previous_route_node_state: previous route node state :param target_node_id: target node id :return: bool """ diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py index f2a52c6dab..17212440f7 100644 --- a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -1,29 +1,26 @@ from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState class BranchIdentifyRunConditionHandler(RunConditionHandler): def check(self, graph_runtime_state: GraphRuntimeState, - source_node_id: str, + previous_route_node_state: RouteNodeState, target_node_id: str) -> bool: """ Check if the condition can be executed :param graph_runtime_state: graph runtime state - :param source_node_id: source node id + :param previous_route_node_state: previous route node state :param target_node_id: target node id :return: bool """ if not self.condition.branch_identify: raise Exception("Branch identify is required") - node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id) - if not node_route_state: - return False - - run_result = node_route_state.node_run_result + run_result = previous_route_node_state.node_run_result if not run_result: return False diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index 15946bdd84..fce5ef13f1 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -1,18 +1,19 @@ from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.utils.condition.processor import ConditionProcessor class ConditionRunConditionHandlerHandler(RunConditionHandler): def check(self, graph_runtime_state: GraphRuntimeState, - source_node_id: str, + previous_route_node_state: RouteNodeState, target_node_id: str) -> bool: """ Check if the condition can be executed :param graph_runtime_state: graph runtime state - :param source_node_id: source node id + :param previous_route_node_state: previous route node state :param target_node_id: target node id :return: bool """ diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 00ce48571b..627eba19d4 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -167,7 +167,7 @@ class GraphEngine: run_condition=edge.run_condition, ).check( graph_runtime_state=self.graph_runtime_state, - source_node_id=edge.source_node_id, + previous_route_node_state=previous_route_node_state, target_node_id=edge.target_node_id, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 4b25b2a3fb..ef6c5bf289 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -21,7 +21,7 @@ from models.workflow import WorkflowType @patch('extensions.ext_database.db.session.remove') @patch('extensions.ext_database.db.session.close') -def test_run(mock_close, mock_remove): +def test_run_parallel(mock_close, mock_remove): graph_config = { "edges": [ { @@ -127,7 +127,7 @@ def test_run(mock_close, mock_remove): max_execution_time=1200 ) - print("") + # print("") app = Flask('test') @@ -135,7 +135,7 @@ def test_run(mock_close, mock_remove): with app.app_context(): generator = graph_engine.run() for item in generator: - print(type(item), item) + # print(type(item), item) items.append(item) if isinstance(item, NodeRunSucceededEvent): assert item.route_node_state.status == RouteNodeState.Status.SUCCESS @@ -153,4 +153,146 @@ def test_run(mock_close, mock_remove): assert isinstance(items[2], NodeRunSucceededEvent) assert items[2].route_node_state.node_id == 'start' - print(graph_engine.graph_runtime_state) + +@patch('extensions.ext_database.db.session.remove') +@patch('extensions.ext_database.db.session.close') +def test_run_branch(mock_close, mock_remove): + graph_config = { + "edges": [{ + "id": "1", + "source": "start", + "target": "if-else-1", + }, { + "id": "2", + "source": "if-else-1", + "sourceHandle": "true", + "target": "answer-1", + }, { + "id": "3", + "source": "if-else-1", + "sourceHandle": "false", + "target": "if-else-2", + }, { + "id": "4", + "source": "if-else-2", + "sourceHandle": "true", + "target": "answer-2", + }, { + "id": "5", + "source": "if-else-2", + "sourceHandle": "false", + "target": "answer-3", + }], + "nodes": [{ + "data": { + "title": "Start", + "type": "start", + "variables": [] + }, + "id": "start" + }, { + "data": { + "answer": "1", + "title": "Answer", + "type": "answer", + "variables": [] + }, + "id": "answer-1", + }, { + "data": { + "cases": [{ + "case_id": "true", + "conditions": [{ + "comparison_operator": "contains", + "id": "b0f02473-08b6-4a81-af91-15345dcb2ec8", + "value": "hi", + "varType": "string", + "variable_selector": ["sys", "query"] + }], + "id": "true", + "logical_operator": "and" + }], + "desc": "", + "title": "IF/ELSE", + "type": "if-else" + }, + "id": "if-else-1", + }, { + "data": { + "cases": [{ + "case_id": "true", + "conditions": [{ + "comparison_operator": "contains", + "id": "ae895199-5608-433b-b5f0-0997ae1431e4", + "value": "takatost", + "varType": "string", + "variable_selector": ["sys", "query"] + }], + "id": "true", + "logical_operator": "and" + }], + "title": "IF/ELSE 2", + "type": "if-else" + }, + "id": "if-else-2", + }, { + "data": { + "answer": "2", + "title": "Answer 2", + "type": "answer", + }, + "id": "answer-2", + }, { + "data": { + "answer": "3", + "title": "Answer 3", + "type": "answer", + }, + "id": "answer-3", + }] + } + + graph = Graph.init( + graph_config=graph_config + ) + + variable_pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'hi', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200 + ) + + # print("") + + app = Flask('test') + + items = [] + with app.app_context(): + generator = graph_engine.run() + for item in generator: + # print(type(item), item) + items.append(item) + + assert len(items) == 8 + assert items[3].route_node_state.node_id == 'if-else-1' + assert items[4].route_node_state.node_id == 'if-else-1' + assert items[5].route_node_state.node_id == 'answer-1' + assert items[6].route_node_state.node_id == 'answer-1' + + # print(graph_engine.graph_runtime_state.model_dump_json(indent=2))