test(api): Add some tests for DraftVariable related code

This commit is contained in:
QuantumGhost 2025-05-26 14:23:34 +08:00
parent 7165333468
commit c730f0fcf2
5 changed files with 234 additions and 26 deletions

View File

@ -4,40 +4,42 @@ import uuid
import pytest
from sqlalchemy.orm import Session
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from models import db
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
@pytest.mark.usefixtures("flask_req_ctx")
class TestWorkflowDraftVariableService(unittest.TestCase):
_test_app_id: str
_session: Session
_node1_id = "test_node_1"
_node2_id = "test_node_2"
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._session: Session = db.session
sys_var = WorkflowDraftVariable.create_sys_variable(
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
)
conv_var = WorkflowDraftVariable.create_conversation_variable(
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
name="conv_var",
value=build_segment("conv_value"),
)
node2_vars = [
WorkflowDraftVariable.create_node_variable(
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="int_var",
value=build_segment(1),
visible=False,
),
WorkflowDraftVariable.create_node_variable(
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="str_var",
@ -45,9 +47,9 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
visible=True,
),
]
node1_var = WorkflowDraftVariable.create_node_variable(
node1_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id="node_1",
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
@ -92,7 +94,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_get_node_variable(self):
srv = self._get_test_srv()
node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var")
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
assert node_var.id == self._node1_str_var_id
assert node_var.name == "str_var"
assert node_var.get_value() == build_segment("str_value")
@ -138,5 +140,86 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test__list_node_variables(self):
srv = self._get_test_srv()
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
assert len(node_vars) == 2
assert {v.id for v in node_vars} == set(self._node2_var_ids)
assert len(node_vars.variables) == 2
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
def test_get_draft_variables_by_selectors(self):
srv = self._get_test_srv()
selectors = [
[self._node1_id, "str_var"],
[self._node2_id, "str_var"],
[self._node2_id, "int_var"],
]
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
assert len(variables) == 3
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
@pytest.mark.usefixtures("flask_req_ctx")
class TestDraftVariableLoader(unittest.TestCase):
_test_app_id: str
_node1_id = "test_loader_node_1"
def setUp(self):
self._test_app_id = str(uuid.uuid4())
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
)
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
name="conv_var",
value=build_segment("conv_value"),
)
node_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
)
_variables = [
node_var,
sys_var,
conv_var,
]
with Session(bind=db.engine, expire_on_commit=False) as session:
session.add_all(_variables)
session.flush()
session.commit()
self._variable_ids = [v.id for v in _variables]
self._node_var_id = node_var.id
self._sys_var_id = sys_var.id
self._conv_var_id = conv_var.id
def tearDown(self):
with Session(bind=db.engine, expire_on_commit=False) as session:
session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
synchronize_session=False
)
session.commit()
def test_variable_loader_with_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
variables = var_loader.load_variables([])
assert len(variables) == 0
def test_variable_loader_with_non_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
variables = var_loader.load_variables(
[
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
[CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
[self._node1_id, "str_var"],
]
)
assert len(variables) == 3
conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID)
assert conv_var.id == self._conv_var_id
sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
assert sys_var.id == self._sys_var_id
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
assert node1_var.id == self._node_var_id

View File

@ -1,3 +1,4 @@
import math
from dataclasses import dataclass
from uuid import uuid4
@ -232,7 +233,11 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File]:
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
assert seg.value == value
# nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here.
if isinstance(value, float) and math.isnan(value):
assert math.isnan(seg.value)
else:
assert seg.value == value
@given(st.lists(_scalar_value()))

View File

@ -1,22 +1,10 @@
from core.variables import SecretVariable
import dataclasses
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.utils import variable_template_parser
def test_extract_selectors_from_template():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
],
conversation_variables=[],
)
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
@ -26,3 +14,35 @@ def test_extract_selectors_from_template():
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
]
def test_invalid_references():
@dataclasses.dataclass
class TestCase:
name: str
template: str
cases = [
TestCase(
name="lack of closing brace",
template="Hello, {{#sys.user_id#",
),
TestCase(
name="lack of opening brace",
template="Hello, #sys.user_id#}}",
),
TestCase(
name="lack selector name",
template="Hello, {{#sys#}}",
),
TestCase(
name="empty node name part",
template="Hello, {{#.user_id#}}",
),
]
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
selectors = variable_template_parser.extract_selectors_from_template(c.template)
assert selectors == [], fail_msg
parser = variable_template_parser.VariableTemplateParser(c.template)
assert parser.extract_variable_selectors() == [], fail_msg

View File

@ -4,7 +4,7 @@ from uuid import uuid4
from constants import HIDDEN_VALUE
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow, WorkflowNodeExecution
from models.workflow import Workflow, WorkflowNodeExecution, is_system_variable_editable
def test_environment_variables():
@ -163,3 +163,21 @@ class TestWorkflowNodeExecution:
original = {"a": 1, "b": ["2"]}
node_exec.execution_metadata = json.dumps(original)
assert node_exec.execution_metadata_dict == original
class TestIsSystemVariableEditable:
def test_is_system_variable(self):
cases = [
("query", True),
("files", True),
("dialogue_count", False),
("conversation_id", False),
("user_id", False),
("app_id", False),
("workflow_id", False),
("workflow_run_id", False),
]
for name, editable in cases:
assert editable == is_system_variable_editable(name)
assert is_system_variable_editable("invalid_or_new_system_variable") == False

View File

@ -0,0 +1,82 @@
import dataclasses
import secrets
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from factories.variable_factory import build_segment
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import _DraftVariableBuilder
class TestDraftVariableBuilder:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test_get_variables(self):
test_app_id = self._get_test_app_id()
builder = _DraftVariableBuilder(app_id=test_app_id)
variables = [
WorkflowDraftVariable.new_node_variable(
app_id=test_app_id,
node_id="test_node_1",
name="test_var_1",
value=build_segment("test_value_1"),
visible=True,
),
WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="test_sys_var",
value=build_segment("test_sys_value"),
),
WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id,
name="test_conv_var",
value=build_segment("test_conv_value"),
),
]
builder._draft_vars = variables
assert builder.get_variables() == variables
def test__should_variable_be_visible(self):
assert _DraftVariableBuilder._should_variable_be_visible(NodeType.IF_ELSE, "123_456", "output") == False
assert _DraftVariableBuilder._should_variable_be_visible(NodeType.START, "123", "output") == True
def test__normalize_variable_for_start_node(self):
@dataclasses.dataclass(frozen=True)
class TestCase:
name: str
input_node_id: str
input_name: str
expected_node_id: str
expected_name: str
_NODE_ID = "1747228642872"
cases = [
TestCase(
name="name with `sys.` prefix should return the system node_id",
input_node_id=_NODE_ID,
input_name="sys.workflow_id",
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
expected_name="workflow_id",
),
TestCase(
name="name without `sys.` prefix should return the original input node_id",
input_node_id=_NODE_ID,
input_name="start_input",
expected_node_id=_NODE_ID,
expected_name="start_input",
),
TestCase(
name="dummy_variable should return the original input node_id",
input_node_id=_NODE_ID,
input_name="__dummy__",
expected_node_id=_NODE_ID,
expected_name="__dummy__",
),
]
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
node_id, name = _DraftVariableBuilder._normalize_variable_for_start_node(c.input_node_id, c.input_name)
assert node_id == c.expected_node_id, fail_msg
assert name == c.expected_name, fail_msg