From f1fd05ccda7eca9637d2c904956a548af422a717 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:23:34 +0800 Subject: [PATCH] test(api): Add some tests for DraftVariable related code --- .../test_workflow_draft_variable_service.py | 103 ++++++++++++++++-- .../core/app/segments/test_factory.py | 7 +- .../utils/test_variable_template_parser.py | 48 +++++--- api/tests/unit_tests/models/test_workflow.py | 20 +++- .../test_workflow_draft_variable_service.py | 82 ++++++++++++++ 5 files changed, 234 insertions(+), 26 deletions(-) create mode 100644 api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index fbe7826b3a..8288dc7bd8 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -4,40 +4,42 @@ import uuid import pytest from sqlalchemy.orm import Session +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment from models import db from models.workflow import WorkflowDraftVariable -from services.workflow_draft_variable_service import WorkflowDraftVariableService +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService @pytest.mark.usefixtures("flask_req_ctx") class TestWorkflowDraftVariableService(unittest.TestCase): _test_app_id: str _session: Session + _node1_id = "test_node_1" _node2_id = "test_node_2" def setUp(self): self._test_app_id = str(uuid.uuid4()) self._session: Session = db.session - sys_var = WorkflowDraftVariable.create_sys_variable( + sys_var = WorkflowDraftVariable.new_sys_variable( app_id=self._test_app_id, name="sys_var", value=build_segment("sys_value"), ) - conv_var = WorkflowDraftVariable.create_conversation_variable( + conv_var = WorkflowDraftVariable.new_conversation_variable( app_id=self._test_app_id, name="conv_var", value=build_segment("conv_value"), ) node2_vars = [ - WorkflowDraftVariable.create_node_variable( + WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, node_id=self._node2_id, name="int_var", value=build_segment(1), visible=False, ), - WorkflowDraftVariable.create_node_variable( + WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, node_id=self._node2_id, name="str_var", @@ -45,9 +47,9 @@ class TestWorkflowDraftVariableService(unittest.TestCase): visible=True, ), ] - node1_var = WorkflowDraftVariable.create_node_variable( + node1_var = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, - node_id="node_1", + node_id=self._node1_id, name="str_var", value=build_segment("str_value"), visible=True, @@ -92,7 +94,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_get_node_variable(self): srv = self._get_test_srv() - node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var") + node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var") assert node_var.id == self._node1_str_var_id assert node_var.name == "str_var" assert node_var.get_value() == build_segment("str_value") @@ -138,5 +140,86 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test__list_node_variables(self): srv = self._get_test_srv() node_vars = srv._list_node_variables(self._test_app_id, self._node2_id) - assert len(node_vars) == 2 - assert {v.id for v in node_vars} == set(self._node2_var_ids) + assert len(node_vars.variables) == 2 + assert {v.id for v in node_vars.variables} == set(self._node2_var_ids) + + def test_get_draft_variables_by_selectors(self): + srv = self._get_test_srv() + selectors = [ + [self._node1_id, "str_var"], + [self._node2_id, "str_var"], + [self._node2_id, "int_var"], + ] + variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors) + assert len(variables) == 3 + assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids) + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestDraftVariableLoader(unittest.TestCase): + _test_app_id: str + + _node1_id = "test_loader_node_1" + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + ) + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var", + value=build_segment("conv_value"), + ) + node_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node1_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + ) + _variables = [ + node_var, + sys_var, + conv_var, + ] + + with Session(bind=db.engine, expire_on_commit=False) as session: + session.add_all(_variables) + session.flush() + session.commit() + self._variable_ids = [v.id for v in _variables] + self._node_var_id = node_var.id + self._sys_var_id = sys_var.id + self._conv_var_id = conv_var.id + + def tearDown(self): + with Session(bind=db.engine, expire_on_commit=False) as session: + session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete( + synchronize_session=False + ) + session.commit() + + def test_variable_loader_with_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id) + variables = var_loader.load_variables([]) + assert len(variables) == 0 + + def test_variable_loader_with_non_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id) + variables = var_loader.load_variables( + [ + [SYSTEM_VARIABLE_NODE_ID, "sys_var"], + [CONVERSATION_VARIABLE_NODE_ID, "conv_var"], + [self._node1_id, "str_var"], + ] + ) + assert len(variables) == 3 + conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID) + assert conv_var.id == self._conv_var_id + sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID) + assert sys_var.id == self._sys_var_id + node1_var = next(v for v in variables if v.selector[0] == self._node1_id) + assert node1_var.id == self._node_var_id diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index 68fc85aa17..725351d429 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -1,3 +1,4 @@ +import math from dataclasses import dataclass from uuid import uuid4 @@ -232,7 +233,11 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File]: @given(_scalar_value()) def test_build_segment_and_extract_values_for_scalar_types(value): seg = variable_factory.build_segment(value) - assert seg.value == value + # nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here. + if isinstance(value, float) and math.isnan(value): + assert math.isnan(seg.value) + else: + assert seg.value == value @given(st.lists(_scalar_value())) diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py index 2f90afcf89..28ef05edde 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py @@ -1,22 +1,10 @@ -from core.variables import SecretVariable +import dataclasses + from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.utils import variable_template_parser def test_extract_selectors_from_template(): - variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, - user_inputs={}, - environment_variables=[ - SecretVariable(name="secret_key", value="fake-secret-key"), - ], - conversation_variables=[], - ) - variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." ) @@ -26,3 +14,35 @@ def test_extract_selectors_from_template(): VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), ] + + +def test_invalid_references(): + @dataclasses.dataclass + class TestCase: + name: str + template: str + + cases = [ + TestCase( + name="lack of closing brace", + template="Hello, {{#sys.user_id#", + ), + TestCase( + name="lack of opening brace", + template="Hello, #sys.user_id#}}", + ), + TestCase( + name="lack selector name", + template="Hello, {{#sys#}}", + ), + TestCase( + name="empty node name part", + template="Hello, {{#.user_id#}}", + ), + ] + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + selectors = variable_template_parser.extract_selectors_from_template(c.template) + assert selectors == [], fail_msg + parser = variable_template_parser.VariableTemplateParser(c.template) + assert parser.extract_variable_selectors() == [], fail_msg diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 70ce224eb6..e7633e6141 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -5,7 +5,7 @@ from uuid import uuid4 import contexts from constants import HIDDEN_VALUE from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from models.workflow import Workflow, WorkflowNodeExecution +from models.workflow import Workflow, WorkflowNodeExecution, is_system_variable_editable def test_environment_variables(): @@ -149,3 +149,21 @@ class TestWorkflowNodeExecution: original = {"a": 1, "b": ["2"]} node_exec.execution_metadata = json.dumps(original) assert node_exec.execution_metadata_dict == original + + +class TestIsSystemVariableEditable: + def test_is_system_variable(self): + cases = [ + ("query", True), + ("files", True), + ("dialogue_count", False), + ("conversation_id", False), + ("user_id", False), + ("app_id", False), + ("workflow_id", False), + ("workflow_run_id", False), + ] + for name, editable in cases: + assert editable == is_system_variable_editable(name) + + assert is_system_variable_editable("invalid_or_new_system_variable") == False diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py new file mode 100644 index 0000000000..c1da6eaede --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -0,0 +1,82 @@ +import dataclasses +import secrets + +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.nodes import NodeType +from factories.variable_factory import build_segment +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import _DraftVariableBuilder + + +class TestDraftVariableBuilder: + def _get_test_app_id(self): + suffix = secrets.token_hex(6) + return f"test_app_id_{suffix}" + + def test_get_variables(self): + test_app_id = self._get_test_app_id() + builder = _DraftVariableBuilder(app_id=test_app_id) + variables = [ + WorkflowDraftVariable.new_node_variable( + app_id=test_app_id, + node_id="test_node_1", + name="test_var_1", + value=build_segment("test_value_1"), + visible=True, + ), + WorkflowDraftVariable.new_sys_variable( + app_id=test_app_id, + name="test_sys_var", + value=build_segment("test_sys_value"), + ), + WorkflowDraftVariable.new_conversation_variable( + app_id=test_app_id, + name="test_conv_var", + value=build_segment("test_conv_value"), + ), + ] + builder._draft_vars = variables + assert builder.get_variables() == variables + + def test__should_variable_be_visible(self): + assert _DraftVariableBuilder._should_variable_be_visible(NodeType.IF_ELSE, "123_456", "output") == False + assert _DraftVariableBuilder._should_variable_be_visible(NodeType.START, "123", "output") == True + + def test__normalize_variable_for_start_node(self): + @dataclasses.dataclass(frozen=True) + class TestCase: + name: str + input_node_id: str + input_name: str + expected_node_id: str + expected_name: str + + _NODE_ID = "1747228642872" + cases = [ + TestCase( + name="name with `sys.` prefix should return the system node_id", + input_node_id=_NODE_ID, + input_name="sys.workflow_id", + expected_node_id=SYSTEM_VARIABLE_NODE_ID, + expected_name="workflow_id", + ), + TestCase( + name="name without `sys.` prefix should return the original input node_id", + input_node_id=_NODE_ID, + input_name="start_input", + expected_node_id=_NODE_ID, + expected_name="start_input", + ), + TestCase( + name="dummy_variable should return the original input node_id", + input_node_id=_NODE_ID, + input_name="__dummy__", + expected_node_id=_NODE_ID, + expected_name="__dummy__", + ), + ] + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + node_id, name = _DraftVariableBuilder._normalize_variable_for_start_node(c.input_node_id, c.input_name) + assert node_id == c.expected_node_id, fail_msg + assert name == c.expected_name, fail_msg