mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 11:15:57 +08:00
test(api): Add some tests for DraftVariable related code
This commit is contained in:
parent
7165333468
commit
c730f0fcf2
@ -4,40 +4,42 @@ import uuid
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import build_segment
|
||||
from models import db
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
_test_app_id: str
|
||||
_session: Session
|
||||
_node1_id = "test_node_1"
|
||||
_node2_id = "test_node_2"
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._session: Session = db.session
|
||||
sys_var = WorkflowDraftVariable.create_sys_variable(
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.create_conversation_variable(
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node2_vars = [
|
||||
WorkflowDraftVariable.create_node_variable(
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node2_id,
|
||||
name="int_var",
|
||||
value=build_segment(1),
|
||||
visible=False,
|
||||
),
|
||||
WorkflowDraftVariable.create_node_variable(
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node2_id,
|
||||
name="str_var",
|
||||
@ -45,9 +47,9 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
node1_var = WorkflowDraftVariable.create_node_variable(
|
||||
node1_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id="node_1",
|
||||
node_id=self._node1_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
@ -92,7 +94,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
|
||||
def test_get_node_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var")
|
||||
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
|
||||
assert node_var.id == self._node1_str_var_id
|
||||
assert node_var.name == "str_var"
|
||||
assert node_var.get_value() == build_segment("str_value")
|
||||
@ -138,5 +140,86 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
def test__list_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
||||
assert len(node_vars) == 2
|
||||
assert {v.id for v in node_vars} == set(self._node2_var_ids)
|
||||
assert len(node_vars.variables) == 2
|
||||
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
|
||||
|
||||
def test_get_draft_variables_by_selectors(self):
|
||||
srv = self._get_test_srv()
|
||||
selectors = [
|
||||
[self._node1_id, "str_var"],
|
||||
[self._node2_id, "str_var"],
|
||||
[self._node2_id, "int_var"],
|
||||
]
|
||||
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
|
||||
assert len(variables) == 3
|
||||
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestDraftVariableLoader(unittest.TestCase):
|
||||
_test_app_id: str
|
||||
|
||||
_node1_id = "test_loader_node_1"
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node1_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
)
|
||||
_variables = [
|
||||
node_var,
|
||||
sys_var,
|
||||
conv_var,
|
||||
]
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
session.add_all(_variables)
|
||||
session.flush()
|
||||
session.commit()
|
||||
self._variable_ids = [v.id for v in _variables]
|
||||
self._node_var_id = node_var.id
|
||||
self._sys_var_id = sys_var.id
|
||||
self._conv_var_id = conv_var.id
|
||||
|
||||
def tearDown(self):
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
def test_variable_loader_with_empty_selector(self):
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
|
||||
variables = var_loader.load_variables([])
|
||||
assert len(variables) == 0
|
||||
|
||||
def test_variable_loader_with_non_empty_selector(self):
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
|
||||
variables = var_loader.load_variables(
|
||||
[
|
||||
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
|
||||
[self._node1_id, "str_var"],
|
||||
]
|
||||
)
|
||||
assert len(variables) == 3
|
||||
conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID)
|
||||
assert conv_var.id == self._conv_var_id
|
||||
sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
|
||||
assert sys_var.id == self._sys_var_id
|
||||
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
|
||||
assert node1_var.id == self._node_var_id
|
||||
|
@ -1,3 +1,4 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from uuid import uuid4
|
||||
|
||||
@ -232,7 +233,11 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File]:
|
||||
@given(_scalar_value())
|
||||
def test_build_segment_and_extract_values_for_scalar_types(value):
|
||||
seg = variable_factory.build_segment(value)
|
||||
assert seg.value == value
|
||||
# nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here.
|
||||
if isinstance(value, float) and math.isnan(value):
|
||||
assert math.isnan(seg.value)
|
||||
else:
|
||||
assert seg.value == value
|
||||
|
||||
|
||||
@given(st.lists(_scalar_value()))
|
||||
|
@ -1,22 +1,10 @@
|
||||
from core.variables import SecretVariable
|
||||
import dataclasses
|
||||
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.utils import variable_template_parser
|
||||
|
||||
|
||||
def test_extract_selectors_from_template():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey("user_id"): "fake-user-id",
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[
|
||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
||||
],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
|
||||
template = (
|
||||
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
||||
)
|
||||
@ -26,3 +14,35 @@ def test_extract_selectors_from_template():
|
||||
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
|
||||
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
|
||||
]
|
||||
|
||||
|
||||
def test_invalid_references():
|
||||
@dataclasses.dataclass
|
||||
class TestCase:
|
||||
name: str
|
||||
template: str
|
||||
|
||||
cases = [
|
||||
TestCase(
|
||||
name="lack of closing brace",
|
||||
template="Hello, {{#sys.user_id#",
|
||||
),
|
||||
TestCase(
|
||||
name="lack of opening brace",
|
||||
template="Hello, #sys.user_id#}}",
|
||||
),
|
||||
TestCase(
|
||||
name="lack selector name",
|
||||
template="Hello, {{#sys#}}",
|
||||
),
|
||||
TestCase(
|
||||
name="empty node name part",
|
||||
template="Hello, {{#.user_id#}}",
|
||||
),
|
||||
]
|
||||
for idx, c in enumerate(cases, 1):
|
||||
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||
selectors = variable_template_parser.extract_selectors_from_template(c.template)
|
||||
assert selectors == [], fail_msg
|
||||
parser = variable_template_parser.VariableTemplateParser(c.template)
|
||||
assert parser.extract_variable_selectors() == [], fail_msg
|
||||
|
@ -4,7 +4,7 @@ from uuid import uuid4
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
||||
from models.workflow import Workflow, WorkflowNodeExecution
|
||||
from models.workflow import Workflow, WorkflowNodeExecution, is_system_variable_editable
|
||||
|
||||
|
||||
def test_environment_variables():
|
||||
@ -163,3 +163,21 @@ class TestWorkflowNodeExecution:
|
||||
original = {"a": 1, "b": ["2"]}
|
||||
node_exec.execution_metadata = json.dumps(original)
|
||||
assert node_exec.execution_metadata_dict == original
|
||||
|
||||
|
||||
class TestIsSystemVariableEditable:
|
||||
def test_is_system_variable(self):
|
||||
cases = [
|
||||
("query", True),
|
||||
("files", True),
|
||||
("dialogue_count", False),
|
||||
("conversation_id", False),
|
||||
("user_id", False),
|
||||
("app_id", False),
|
||||
("workflow_id", False),
|
||||
("workflow_run_id", False),
|
||||
]
|
||||
for name, editable in cases:
|
||||
assert editable == is_system_variable_editable(name)
|
||||
|
||||
assert is_system_variable_editable("invalid_or_new_system_variable") == False
|
||||
|
@ -0,0 +1,82 @@
|
||||
import dataclasses
|
||||
import secrets
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes import NodeType
|
||||
from factories.variable_factory import build_segment
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import _DraftVariableBuilder
|
||||
|
||||
|
||||
class TestDraftVariableBuilder:
|
||||
def _get_test_app_id(self):
|
||||
suffix = secrets.token_hex(6)
|
||||
return f"test_app_id_{suffix}"
|
||||
|
||||
def test_get_variables(self):
|
||||
test_app_id = self._get_test_app_id()
|
||||
builder = _DraftVariableBuilder(app_id=test_app_id)
|
||||
variables = [
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id,
|
||||
node_id="test_node_1",
|
||||
name="test_var_1",
|
||||
value=build_segment("test_value_1"),
|
||||
visible=True,
|
||||
),
|
||||
WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="test_sys_var",
|
||||
value=build_segment("test_sys_value"),
|
||||
),
|
||||
WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id,
|
||||
name="test_conv_var",
|
||||
value=build_segment("test_conv_value"),
|
||||
),
|
||||
]
|
||||
builder._draft_vars = variables
|
||||
assert builder.get_variables() == variables
|
||||
|
||||
def test__should_variable_be_visible(self):
|
||||
assert _DraftVariableBuilder._should_variable_be_visible(NodeType.IF_ELSE, "123_456", "output") == False
|
||||
assert _DraftVariableBuilder._should_variable_be_visible(NodeType.START, "123", "output") == True
|
||||
|
||||
def test__normalize_variable_for_start_node(self):
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TestCase:
|
||||
name: str
|
||||
input_node_id: str
|
||||
input_name: str
|
||||
expected_node_id: str
|
||||
expected_name: str
|
||||
|
||||
_NODE_ID = "1747228642872"
|
||||
cases = [
|
||||
TestCase(
|
||||
name="name with `sys.` prefix should return the system node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="sys.workflow_id",
|
||||
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||
expected_name="workflow_id",
|
||||
),
|
||||
TestCase(
|
||||
name="name without `sys.` prefix should return the original input node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="start_input",
|
||||
expected_node_id=_NODE_ID,
|
||||
expected_name="start_input",
|
||||
),
|
||||
TestCase(
|
||||
name="dummy_variable should return the original input node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="__dummy__",
|
||||
expected_node_id=_NODE_ID,
|
||||
expected_name="__dummy__",
|
||||
),
|
||||
]
|
||||
for idx, c in enumerate(cases, 1):
|
||||
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||
node_id, name = _DraftVariableBuilder._normalize_variable_for_start_node(c.input_node_id, c.input_name)
|
||||
assert node_id == c.expected_node_id, fail_msg
|
||||
assert name == c.expected_name, fail_msg
|
Loading…
x
Reference in New Issue
Block a user