mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-16 19:05:55 +08:00
test(api): Add some tests for DraftVariable related code
This commit is contained in:
parent
4cf9d23069
commit
f1fd05ccda
@ -4,40 +4,42 @@ import uuid
|
|||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from factories.variable_factory import build_segment
|
from factories.variable_factory import build_segment
|
||||||
from models import db
|
from models import db
|
||||||
from models.workflow import WorkflowDraftVariable
|
from models.workflow import WorkflowDraftVariable
|
||||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("flask_req_ctx")
|
@pytest.mark.usefixtures("flask_req_ctx")
|
||||||
class TestWorkflowDraftVariableService(unittest.TestCase):
|
class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||||
_test_app_id: str
|
_test_app_id: str
|
||||||
_session: Session
|
_session: Session
|
||||||
|
_node1_id = "test_node_1"
|
||||||
_node2_id = "test_node_2"
|
_node2_id = "test_node_2"
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._test_app_id = str(uuid.uuid4())
|
self._test_app_id = str(uuid.uuid4())
|
||||||
self._session: Session = db.session
|
self._session: Session = db.session
|
||||||
sys_var = WorkflowDraftVariable.create_sys_variable(
|
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
name="sys_var",
|
name="sys_var",
|
||||||
value=build_segment("sys_value"),
|
value=build_segment("sys_value"),
|
||||||
)
|
)
|
||||||
conv_var = WorkflowDraftVariable.create_conversation_variable(
|
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
name="conv_var",
|
name="conv_var",
|
||||||
value=build_segment("conv_value"),
|
value=build_segment("conv_value"),
|
||||||
)
|
)
|
||||||
node2_vars = [
|
node2_vars = [
|
||||||
WorkflowDraftVariable.create_node_variable(
|
WorkflowDraftVariable.new_node_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
node_id=self._node2_id,
|
node_id=self._node2_id,
|
||||||
name="int_var",
|
name="int_var",
|
||||||
value=build_segment(1),
|
value=build_segment(1),
|
||||||
visible=False,
|
visible=False,
|
||||||
),
|
),
|
||||||
WorkflowDraftVariable.create_node_variable(
|
WorkflowDraftVariable.new_node_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
node_id=self._node2_id,
|
node_id=self._node2_id,
|
||||||
name="str_var",
|
name="str_var",
|
||||||
@ -45,9 +47,9 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
|||||||
visible=True,
|
visible=True,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
node1_var = WorkflowDraftVariable.create_node_variable(
|
node1_var = WorkflowDraftVariable.new_node_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
node_id="node_1",
|
node_id=self._node1_id,
|
||||||
name="str_var",
|
name="str_var",
|
||||||
value=build_segment("str_value"),
|
value=build_segment("str_value"),
|
||||||
visible=True,
|
visible=True,
|
||||||
@ -92,7 +94,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
|||||||
|
|
||||||
def test_get_node_variable(self):
|
def test_get_node_variable(self):
|
||||||
srv = self._get_test_srv()
|
srv = self._get_test_srv()
|
||||||
node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var")
|
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
|
||||||
assert node_var.id == self._node1_str_var_id
|
assert node_var.id == self._node1_str_var_id
|
||||||
assert node_var.name == "str_var"
|
assert node_var.name == "str_var"
|
||||||
assert node_var.get_value() == build_segment("str_value")
|
assert node_var.get_value() == build_segment("str_value")
|
||||||
@ -138,5 +140,86 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
|||||||
def test__list_node_variables(self):
|
def test__list_node_variables(self):
|
||||||
srv = self._get_test_srv()
|
srv = self._get_test_srv()
|
||||||
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
||||||
assert len(node_vars) == 2
|
assert len(node_vars.variables) == 2
|
||||||
assert {v.id for v in node_vars} == set(self._node2_var_ids)
|
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
|
||||||
|
|
||||||
|
def test_get_draft_variables_by_selectors(self):
|
||||||
|
srv = self._get_test_srv()
|
||||||
|
selectors = [
|
||||||
|
[self._node1_id, "str_var"],
|
||||||
|
[self._node2_id, "str_var"],
|
||||||
|
[self._node2_id, "int_var"],
|
||||||
|
]
|
||||||
|
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
|
||||||
|
assert len(variables) == 3
|
||||||
|
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("flask_req_ctx")
|
||||||
|
class TestDraftVariableLoader(unittest.TestCase):
|
||||||
|
_test_app_id: str
|
||||||
|
|
||||||
|
_node1_id = "test_loader_node_1"
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._test_app_id = str(uuid.uuid4())
|
||||||
|
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
name="sys_var",
|
||||||
|
value=build_segment("sys_value"),
|
||||||
|
)
|
||||||
|
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
name="conv_var",
|
||||||
|
value=build_segment("conv_value"),
|
||||||
|
)
|
||||||
|
node_var = WorkflowDraftVariable.new_node_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
node_id=self._node1_id,
|
||||||
|
name="str_var",
|
||||||
|
value=build_segment("str_value"),
|
||||||
|
visible=True,
|
||||||
|
)
|
||||||
|
_variables = [
|
||||||
|
node_var,
|
||||||
|
sys_var,
|
||||||
|
conv_var,
|
||||||
|
]
|
||||||
|
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
session.add_all(_variables)
|
||||||
|
session.flush()
|
||||||
|
session.commit()
|
||||||
|
self._variable_ids = [v.id for v in _variables]
|
||||||
|
self._node_var_id = node_var.id
|
||||||
|
self._sys_var_id = sys_var.id
|
||||||
|
self._conv_var_id = conv_var.id
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
|
||||||
|
synchronize_session=False
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def test_variable_loader_with_empty_selector(self):
|
||||||
|
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
|
||||||
|
variables = var_loader.load_variables([])
|
||||||
|
assert len(variables) == 0
|
||||||
|
|
||||||
|
def test_variable_loader_with_non_empty_selector(self):
|
||||||
|
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
|
||||||
|
variables = var_loader.load_variables(
|
||||||
|
[
|
||||||
|
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
|
||||||
|
[CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
|
||||||
|
[self._node1_id, "str_var"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert len(variables) == 3
|
||||||
|
conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID)
|
||||||
|
assert conv_var.id == self._conv_var_id
|
||||||
|
sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
|
||||||
|
assert sys_var.id == self._sys_var_id
|
||||||
|
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
|
||||||
|
assert node1_var.id == self._node_var_id
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@ -232,7 +233,11 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File]:
|
|||||||
@given(_scalar_value())
|
@given(_scalar_value())
|
||||||
def test_build_segment_and_extract_values_for_scalar_types(value):
|
def test_build_segment_and_extract_values_for_scalar_types(value):
|
||||||
seg = variable_factory.build_segment(value)
|
seg = variable_factory.build_segment(value)
|
||||||
assert seg.value == value
|
# nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here.
|
||||||
|
if isinstance(value, float) and math.isnan(value):
|
||||||
|
assert math.isnan(seg.value)
|
||||||
|
else:
|
||||||
|
assert seg.value == value
|
||||||
|
|
||||||
|
|
||||||
@given(st.lists(_scalar_value()))
|
@given(st.lists(_scalar_value()))
|
||||||
|
@ -1,22 +1,10 @@
|
|||||||
from core.variables import SecretVariable
|
import dataclasses
|
||||||
|
|
||||||
from core.workflow.entities.variable_entities import VariableSelector
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.utils import variable_template_parser
|
from core.workflow.utils import variable_template_parser
|
||||||
|
|
||||||
|
|
||||||
def test_extract_selectors_from_template():
|
def test_extract_selectors_from_template():
|
||||||
variable_pool = VariablePool(
|
|
||||||
system_variables={
|
|
||||||
SystemVariableKey("user_id"): "fake-user-id",
|
|
||||||
},
|
|
||||||
user_inputs={},
|
|
||||||
environment_variables=[
|
|
||||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
|
||||||
],
|
|
||||||
conversation_variables=[],
|
|
||||||
)
|
|
||||||
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
|
|
||||||
template = (
|
template = (
|
||||||
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
||||||
)
|
)
|
||||||
@ -26,3 +14,35 @@ def test_extract_selectors_from_template():
|
|||||||
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
|
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
|
||||||
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
|
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_references():
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TestCase:
|
||||||
|
name: str
|
||||||
|
template: str
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
TestCase(
|
||||||
|
name="lack of closing brace",
|
||||||
|
template="Hello, {{#sys.user_id#",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="lack of opening brace",
|
||||||
|
template="Hello, #sys.user_id#}}",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="lack selector name",
|
||||||
|
template="Hello, {{#sys#}}",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="empty node name part",
|
||||||
|
template="Hello, {{#.user_id#}}",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for idx, c in enumerate(cases, 1):
|
||||||
|
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||||
|
selectors = variable_template_parser.extract_selectors_from_template(c.template)
|
||||||
|
assert selectors == [], fail_msg
|
||||||
|
parser = variable_template_parser.VariableTemplateParser(c.template)
|
||||||
|
assert parser.extract_variable_selectors() == [], fail_msg
|
||||||
|
@ -5,7 +5,7 @@ from uuid import uuid4
|
|||||||
import contexts
|
import contexts
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
||||||
from models.workflow import Workflow, WorkflowNodeExecution
|
from models.workflow import Workflow, WorkflowNodeExecution, is_system_variable_editable
|
||||||
|
|
||||||
|
|
||||||
def test_environment_variables():
|
def test_environment_variables():
|
||||||
@ -149,3 +149,21 @@ class TestWorkflowNodeExecution:
|
|||||||
original = {"a": 1, "b": ["2"]}
|
original = {"a": 1, "b": ["2"]}
|
||||||
node_exec.execution_metadata = json.dumps(original)
|
node_exec.execution_metadata = json.dumps(original)
|
||||||
assert node_exec.execution_metadata_dict == original
|
assert node_exec.execution_metadata_dict == original
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsSystemVariableEditable:
|
||||||
|
def test_is_system_variable(self):
|
||||||
|
cases = [
|
||||||
|
("query", True),
|
||||||
|
("files", True),
|
||||||
|
("dialogue_count", False),
|
||||||
|
("conversation_id", False),
|
||||||
|
("user_id", False),
|
||||||
|
("app_id", False),
|
||||||
|
("workflow_id", False),
|
||||||
|
("workflow_run_id", False),
|
||||||
|
]
|
||||||
|
for name, editable in cases:
|
||||||
|
assert editable == is_system_variable_editable(name)
|
||||||
|
|
||||||
|
assert is_system_variable_editable("invalid_or_new_system_variable") == False
|
||||||
|
@ -0,0 +1,82 @@
|
|||||||
|
import dataclasses
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
|
from core.workflow.nodes import NodeType
|
||||||
|
from factories.variable_factory import build_segment
|
||||||
|
from models.workflow import WorkflowDraftVariable
|
||||||
|
from services.workflow_draft_variable_service import _DraftVariableBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class TestDraftVariableBuilder:
|
||||||
|
def _get_test_app_id(self):
|
||||||
|
suffix = secrets.token_hex(6)
|
||||||
|
return f"test_app_id_{suffix}"
|
||||||
|
|
||||||
|
def test_get_variables(self):
|
||||||
|
test_app_id = self._get_test_app_id()
|
||||||
|
builder = _DraftVariableBuilder(app_id=test_app_id)
|
||||||
|
variables = [
|
||||||
|
WorkflowDraftVariable.new_node_variable(
|
||||||
|
app_id=test_app_id,
|
||||||
|
node_id="test_node_1",
|
||||||
|
name="test_var_1",
|
||||||
|
value=build_segment("test_value_1"),
|
||||||
|
visible=True,
|
||||||
|
),
|
||||||
|
WorkflowDraftVariable.new_sys_variable(
|
||||||
|
app_id=test_app_id,
|
||||||
|
name="test_sys_var",
|
||||||
|
value=build_segment("test_sys_value"),
|
||||||
|
),
|
||||||
|
WorkflowDraftVariable.new_conversation_variable(
|
||||||
|
app_id=test_app_id,
|
||||||
|
name="test_conv_var",
|
||||||
|
value=build_segment("test_conv_value"),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
builder._draft_vars = variables
|
||||||
|
assert builder.get_variables() == variables
|
||||||
|
|
||||||
|
def test__should_variable_be_visible(self):
|
||||||
|
assert _DraftVariableBuilder._should_variable_be_visible(NodeType.IF_ELSE, "123_456", "output") == False
|
||||||
|
assert _DraftVariableBuilder._should_variable_be_visible(NodeType.START, "123", "output") == True
|
||||||
|
|
||||||
|
def test__normalize_variable_for_start_node(self):
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class TestCase:
|
||||||
|
name: str
|
||||||
|
input_node_id: str
|
||||||
|
input_name: str
|
||||||
|
expected_node_id: str
|
||||||
|
expected_name: str
|
||||||
|
|
||||||
|
_NODE_ID = "1747228642872"
|
||||||
|
cases = [
|
||||||
|
TestCase(
|
||||||
|
name="name with `sys.` prefix should return the system node_id",
|
||||||
|
input_node_id=_NODE_ID,
|
||||||
|
input_name="sys.workflow_id",
|
||||||
|
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||||
|
expected_name="workflow_id",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="name without `sys.` prefix should return the original input node_id",
|
||||||
|
input_node_id=_NODE_ID,
|
||||||
|
input_name="start_input",
|
||||||
|
expected_node_id=_NODE_ID,
|
||||||
|
expected_name="start_input",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="dummy_variable should return the original input node_id",
|
||||||
|
input_node_id=_NODE_ID,
|
||||||
|
input_name="__dummy__",
|
||||||
|
expected_node_id=_NODE_ID,
|
||||||
|
expected_name="__dummy__",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for idx, c in enumerate(cases, 1):
|
||||||
|
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||||
|
node_id, name = _DraftVariableBuilder._normalize_variable_for_start_node(c.input_node_id, c.input_name)
|
||||||
|
assert node_id == c.expected_node_id, fail_msg
|
||||||
|
assert name == c.expected_name, fail_msg
|
Loading…
x
Reference in New Issue
Block a user