Merge branch 'feat/variable-pool-rebased' into deploy/dev

This commit is contained in:
QuantumGhost 2025-05-26 14:23:55 +08:00
commit a7bfe67797
14 changed files with 508 additions and 131 deletions

View File

@ -75,6 +75,20 @@ class StringSegment(Segment):
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value: float
# NOTE(QuantumGhost): equality between `FloatSegment` instances holding `NaN` values appears to be broken.
# The following test does not pass.
#
# def test_float_segment_and_nan():
# nan = float("nan")
# assert nan != nan
#
# f1 = FloatSegment(value=float("nan"))
# f2 = FloatSegment(value=float("nan"))
# assert f1 != f2
#
# f3 = FloatSegment(value=nan)
# f4 = FloatSegment(value=nan)
# assert f3 != f4
class IntegerSegment(Segment):

View File

@ -96,7 +96,7 @@ class VariablePool(BaseModel):
if isinstance(value, Variable):
variable = value
if isinstance(value, Segment):
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else:
segment = variable_factory.build_segment(value)

View File

@ -90,8 +90,28 @@ class BaseNode(Generic[GenericNodeData]):
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
"""Extracts referenced variable selectors from node configuration.
The `config` parameter represents the configuration for a specific node type and corresponds
to the `data` field in the node definition object.
The returned mapping has the following structure:
{'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
Here, the key consists of two parts: the current node ID (provided as the `node_id`
parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
enclosed in `#` symbols. These two parts are separated by a dot (`.`).
The value is a list of strings representing the variable selector, where the first element is the node ID
of the referenced variable, and the second element is the variable name within that node.
The meaning of the above response is:
The node with ID `1747829548239` references the variable `result` from the node with
ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a
reference to the `result` output variable of node `1747829667553`.
:param graph_config: graph config
:param config: node config
:return:

View File

@ -1,4 +1,5 @@
from typing import Literal
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing_extensions import deprecated
@ -91,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]):
return data
@classmethod
def _extract_variable_selector_to_variable_mapping(
    cls,
    *,
    graph_config: Mapping[str, Any],
    node_id: str,
    node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
    """Collect the variable selectors referenced by the conditions of every case.

    Each key has the form ``"<node_id>.#<selector parts joined by '.'>#"`` and maps to
    the condition's ``variable_selector``. Conditions that reference the same selector
    collapse into a single entry.

    :param graph_config: graph config (not used by this implementation)
    :param node_id: ID of the current node, used as the key prefix
    :param node_data: node data whose ``cases`` are scanned; a falsy ``cases`` yields an empty mapping
    :return: mapping from reference key to variable selector
    """
    var_mapping: dict[str, list[str]] = {}
    for case in node_data.cases or []:
        for condition in case.conditions:
            key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
            var_mapping[key] = condition.variable_selector
    return var_mapping
@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(

View File

@ -0,0 +1,38 @@
import abc
from typing import Protocol
from core.variables import Variable
class VariableLoader(Protocol):
    """Interface for loading variables based on selectors.

    A `VariableLoader` is responsible for retrieving additional variables required during the execution
    of a single node, which are not provided as user inputs.

    NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same
    application and share the same `app_id`. However, this interface does not enforce that constraint,
    and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of
    concerns and allow for flexible implementations.

    Implementations of `VariableLoader` should almost always have an `app_id` parameter in
    their constructor.

    TODO(QuantumGhost): this is a temporary workaround. If we can move the creation of node instance into
    `WorkflowService.single_step_run`, we may get rid of this interface.
    """

    @abc.abstractmethod
    def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
        """Load variables based on the provided selectors. If the selectors are empty,
        this method should return an empty list.

        The order of the returned variables is not guaranteed. If the caller wants to ensure
        a specific order, they should sort the returned list themselves.

        :param selectors: a list of string lists, each inner list should have at least two elements:
            - the first element is the node ID,
            - the second element is the variable name.
        :return: a list of Variable objects that match the provided selectors.
        """
        pass

View File

@ -2,12 +2,13 @@ import logging
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, TypeAlias, TypeVar, cast
from typing import Any, Optional, TypeAlias, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Variable
from core.workflow.callbacks import WorkflowCallback
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey
@ -22,13 +23,27 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.variable_loader import VariableLoader
from factories import file_factory
from libs import gen_utils
from models.enums import UserFrom
from models.workflow import (
Workflow,
WorkflowType,
)
class _DummyVariableLoader(VariableLoader):
    """A dummy implementation of VariableLoader that does not load any variables.

    Serves as a placeholder when no variable loading is needed.
    """

    def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
        """Ignore ``selectors`` and always return an empty list."""
        return []
_DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
logger = logging.getLogger(__name__)
@ -122,6 +137,7 @@ class WorkflowEntry:
user_id: str,
user_inputs: dict,
conversation_variables: dict | None = None,
variable_loader: VariableLoader = _DUMMY_VARIABLE_LOADER,
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
"""
Single step run workflow node
@ -190,6 +206,19 @@ class WorkflowEntry:
except NotImplementedError:
variable_mapping = {}
# Loading missing variable from draft var here, and set it into
# variable_pool.
variables_to_load: list[list[str]] = []
for key, selector in variable_mapping.items():
trimmed_key = key.removeprefix(f"{node_id}.")
if trimmed_key in user_inputs:
continue
if variable_pool.get(selector) is None:
variables_to_load.append(list(selector))
loaded = variable_loader.load_variables(variables_to_load)
for var in loaded:
variable_pool.add(var.selector, var.value)
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
@ -204,7 +233,7 @@ class WorkflowEntry:
except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
if metadata_attacher:
generator = _wrap_generator(generator, metadata_attacher)
generator = gen_utils.map_(generator, metadata_attacher)
return node_instance, generator
@classmethod
@ -391,18 +420,6 @@ class WorkflowEntry:
variable_pool.add([variable_node_id] + variable_key_list, input_value)
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
_YieldR_co = TypeVar("_YieldR_co", covariant=True)
def _wrap_generator(
gen: Generator[_YieldT_co, None, None],
mapper: Callable[[_YieldT_co], _YieldR_co],
) -> Generator[_YieldR_co, None, None]:
for item in gen:
yield mapper(item)
_NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent

53
api/libs/gen_utils.py Normal file
View File

@ -0,0 +1,53 @@
"""Utility functions for working with generators."""
import logging
from collections.abc import Callable, Generator
from inspect import isgenerator
from typing import TypeVar
_YieldT = TypeVar("_YieldT")
_YieldR = TypeVar("_YieldR")
_T = TypeVar("_T")
def inspect(gen_or_normal: _T, logger: logging.Logger) -> _T:
    """Log every item produced by a generator while passing it through unchanged.

    If ``gen_or_normal`` is not a generator object it is returned as-is. Otherwise a
    new generator is returned which yields the same items in the same order, logging
    the type and value of each item at INFO level right before yielding it.

    The wrapped generator's ``return`` value is preserved: it becomes the return
    value of the wrapping generator. Values passed via ``send()`` are not forwarded
    to the wrapped generator.

    :param gen_or_normal: a generator object, or any other value
    :param logger: logger used to record each generated item
    :return: the original value, or a logging generator wrapping it
    """
    if not isgenerator(gen_or_normal):
        return gen_or_normal

    def wrapper():
        # Drive the generator manually instead of using a ``for`` loop so the
        # ``StopIteration`` value (the generator's return value) is not lost.
        while True:
            try:
                item = next(gen_or_normal)
            except StopIteration as stop:
                return stop.value
            logger.info(
                "received generator item, type=%s, value=%s",
                type(item),
                item,
            )
            yield item

    return wrapper()
def map_(
    gen: Generator[_YieldT, None, None],
    mapper: Callable[[_YieldT], _YieldR],
) -> Generator[_YieldR, None, None]:
    """Lazily apply ``mapper`` to every item yielded by ``gen``.

    :param gen: source generator
    :param mapper: function applied to each item just before it is yielded
    :return: a generator yielding ``mapper(item)`` for each item of ``gen``, in order
    """
    for item in gen:
        yield mapper(item)
def filter_(
    gen: Generator[_YieldT, None, None],
    mapper: Callable[[_YieldT], bool],
) -> Generator[_YieldT, None, None]:
    """Lazily yield only the items of ``gen`` accepted by a predicate.

    :param gen: source generator
    :param mapper: predicate called with each item (the name is kept for
        backward compatibility); items for which it returns a truthy value are yielded
    :return: a generator yielding the accepted items in their original order
    """
    for item in gen:
        if mapper(item):
            yield item
def wrap(
    gen: Generator[_YieldT, None, None],
    funcs: list[Callable[[Generator[_YieldT, None, None]], Generator[_YieldT, None, None]]],
) -> Generator[_YieldT, None, None]:
    """Chain generator transformers around ``gen``.

    The transformers are applied in list order: the first element of ``funcs`` wraps
    ``gen`` directly, and each later element wraps the result of the previous one.
    With an empty ``funcs`` the original generator is returned untouched.

    :param gen: source generator
    :param funcs: functions that each take a generator and return a generator
    :return: the generator produced by the last transformer (or ``gen`` itself)
    """
    for f in funcs:
        gen = f(gen)
    return gen

View File

@ -3,17 +3,19 @@ import logging
from collections.abc import Mapping, Sequence
from typing import Any
from sqlalchemy import orm
from sqlalchemy import Engine, orm
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import and_, or_
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.constants import is_dummy_output_variable
from core.variables import Segment
from core.variables import Segment, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from core.workflow.variable_loader import VariableLoader
from factories import variable_factory
from factories.variable_factory import build_segment, segment_to_variable
from models.workflow import WorkflowDraftVariable, is_system_variable_editable
_logger = logging.getLogger(__name__)
@ -25,6 +27,36 @@ class WorkflowDraftVariableList:
total: int | None = None
class DraftVarLoader(VariableLoader):
    """`VariableLoader` implementation that loads draft variables.

    ref: core.workflow.variable_loader.VariableLoader
    """

    def __init__(self, engine: Engine, app_id: str) -> None:
        """Store the engine used to open sessions and the app whose draft variables are loaded."""
        self._engine = engine
        self._app_id = app_id

    def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
        """Load the draft variables matching ``selectors`` and convert them to `Variable` objects.

        :param selectors: variable selectors to look up; an empty list returns ``[]``
            without opening a session.
        :return: one `Variable` per draft variable row returned by
            `WorkflowDraftVariableService.get_draft_variables_by_selectors`.
        """
        if not selectors:
            return []

        variables = []
        with Session(bind=self._engine, expire_on_commit=False) as session:
            srv = WorkflowDraftVariableService(session)
            draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
            # Convert while the session is still open so attribute access on the
            # rows never happens on a detached instance.
            for draft_var in draft_vars:
                segment = build_segment(
                    draft_var.value,
                )
                variable = segment_to_variable(
                    segment=segment,
                    selector=draft_var.get_selector(),
                    id=draft_var.id,
                    name=draft_var.name,
                    description=draft_var.description,
                )
                variables.append(variable)
        return variables
class WorkflowDraftVariableService:
_session: Session
@ -34,6 +66,28 @@ class WorkflowDraftVariableService:
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
def get_draft_variables_by_selectors(
    self,
    app_id: str,
    selectors: Sequence[list[str]],
) -> list[WorkflowDraftVariable]:
    """Fetch the draft variables of an app that match any of the given selectors.

    Only the first two elements of each selector (node ID and variable name) are used
    for the lookup; any additional trailing elements are ignored.

    :param app_id: ID of the app owning the draft variables
    :param selectors: selectors, each with at least two elements ``[node_id, name, ...]``
    :return: matching `WorkflowDraftVariable` rows; ``[]`` when ``selectors`` is empty
    :raises ValueError: if a selector has fewer than two elements
    """
    if not selectors:
        # Nothing to match. Also avoids calling `or_()` with no arguments,
        # which SQLAlchemy deprecates and which would not filter by selector.
        return []

    ors = []
    for selector in selectors:
        if len(selector) < 2:
            raise ValueError(f"invalid selector, expected at least [node_id, name]: {selector!r}")
        # Selectors may carry more than two elements; only node ID and name
        # identify the stored draft variable.
        node_id, name = selector[:2]
        ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))

    # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as
    # each expression includes conditions on both `node_id` and `name` (which are covered by the unique index),
    # PostgreSQL can efficiently retrieve the results using a bitmap index scan.
    #
    # Alternatively, a `SELECT` statement could be constructed for each selector and
    # combined using `UNION` to fetch all rows.
    # Benchmarking indicates that both approaches yield comparable performance.
    variables = (
        self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all()
    )
    return variables
def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]):
variable_builder = _DraftVariableBuilder(app_id=app_id)
variable_builder.build(node_id=node_id, node_type=node_type, output=output)
@ -42,29 +96,25 @@ class WorkflowDraftVariableService:
if not draft_variables:
return
# We may use SQLAlchemy ORM operation here. However, considering the fact that:
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
#
# 1. The variable saving process writes multiple rows into one table (`workflow_draft_variables`).
# Use batch insertion may increase performance dramatically.
# 2. If we use ORM operation, we need to either:
# 1. The variable saving process involves writing multiple rows to the
# `workflow_draft_variables` table. Batch insertion significantly improves performance.
# 2. Using the ORM would require either:
#
# a. Check the existence for each variable before insertion.
# b. Try insertion first, then do update if insertion fails due to unique index violation.
# a. Checking for the existence of each variable before insertion,
# resulting in 2n SQL statements for n variables and potential concurrency issues.
# b. Attempting insertion first, then updating if a unique index violation occurs,
# which still results in n to 2n SQL statements.
#
# Neither of the above is satisfactory.
# Both approaches are inefficient and suboptimal.
# 3. We do not need to retrieve the results of the SQL execution or populate ORM
# model instances with the returned values.
# 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all
# variables in a single SQL statement, avoiding the issues above.
#
# - For implementation "a", we need to issue `2n` sqls for `n` variables in output.
# Besides, it's still suffer from concurrency issues.
# - For implementation "b", we need to issue `n` - `2n` sqls (depending on the existence of
# specific variable), which is lesser than plan "a" but still far from ideal.
#
# 3. We do not need the value of SQL execution, nor do we need populate those values back to ORM model
# instances.
# 4. Batch insertion can be combined with `ON CONFLICT DO UPDATE`, allows us to insert or update
# all variables in one SQL statement, and avoid all problems above.
#
# Given reasons above, we use query builder instead of using ORM layer,
# and rely on dialect specific insert operations.
# For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific
# insert operations instead of the ORM layer.
if node_type == NodeType.CODE:
# Clear existing variable for code node.
self._session.query(WorkflowDraftVariable).filter(
@ -213,26 +263,6 @@ def should_save_output_variables_for_draft(
return True
# def should_save_output_variables_for_draft(invoke_from: InvokeFrom, node_exec: WorkflowNodeExecution) -> bool:
# # Only save output variables for debugging execution of workflow.
# if invoke_from != InvokeFrom.DEBUGGER:
# return False
# exec_metadata = node_exec.execution_metadata_dict
# if exec_metadata is None:
# # No execution metadata, assume the node is not in loop or iteration.
# return True
#
# # Currently we do not save output variables for nodes inside loop or iteration.
# loop_id = exec_metadata.get(NodeRunMetadataKey.LOOP_ID)
# if loop_id is not None:
# return False
# iteration_id = exec_metadata.get(NodeRunMetadataKey.ITERATION_ID)
# if iteration_id is not None:
# return False
# return True
#
class _DraftVariableBuilder:
_app_id: str
_draft_vars: list[WorkflowDraftVariable]
@ -280,7 +310,8 @@ class _DraftVariableBuilder:
original_node_id = node_id
for name, value in output.items():
value_seg = variable_factory.build_segment(value)
if is_dummy_output_variable(name):
node_id, name = self._normalize_variable_for_start_node(node_id, name)
if node_id != SYSTEM_VARIABLE_NODE_ID:
self._draft_vars.append(
WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
@ -302,20 +333,9 @@ class _DraftVariableBuilder:
)
@staticmethod
def _normalize_variable_for_start_node(node_type: NodeType, node_id: str, name: str):
if node_type != NodeType.START:
return node_id, name
# TODO(QuantumGhost): need special handling for dummy output variable in
# `Start` node.
def _normalize_variable_for_start_node(node_id: str, name: str) -> tuple[str, str]:
if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."):
return node_id, name
_logger.debug(
"Normalizing variable: node_type=%s, node_id=%s, name=%s",
node_type,
node_id,
name,
)
node_id, name_ = name.split(".", maxsplit=1)
return node_id, name_

View File

@ -3,7 +3,6 @@ import logging
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from inspect import isgenerator
from typing import Any, Optional
from uuid import uuid4
@ -28,6 +27,7 @@ from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_M
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from libs import gen_utils
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
@ -42,7 +42,11 @@ from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .workflow_draft_variable_service import WorkflowDraftVariableService, should_save_output_variables_for_draft
from .workflow_draft_variable_service import (
DraftVarLoader,
WorkflowDraftVariableService,
should_save_output_variables_for_draft,
)
class WorkflowService:
@ -310,26 +314,28 @@ class WorkflowService:
if not draft_workflow:
raise ValueError("Workflow not initialized")
# conv_vars = common_helpers.get_conversation_variables()
# run draft workflow node
start_at = time.perf_counter()
# TODO(QuantumGhost): We may get rid of the `list_conversation_variables`
# here, and rely on `DraftVarLoader` to load conversation variables.
with Session(bind=db.engine) as session:
# TODO(QuantumGhost): inject conversation variables
# to variable pool.
draft_var_srv = WorkflowDraftVariableService(session)
conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id)
conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables}
node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: WorkflowEntry.single_step_run(
variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id)
run = WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
conversation_variables=conv_var_mapping,
),
variable_loader=variable_loader,
)
# run draft workflow node
start_at = time.perf_counter()
node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: run,
start_at=start_at,
node_id=node_id,
)
@ -403,7 +409,7 @@ class WorkflowService:
) -> NodeExecution:
try:
node_instance, generator = invoke_node_fn()
generator = _inspect_generator(generator)
generator = gen_utils.inspect(generator, logging.getLogger(__name__))
node_run_result: NodeRunResult | None = None
for event in generator:
@ -610,19 +616,3 @@ class WorkflowService:
session.delete(workflow)
return True
def _inspect_generator(gen: Generator[Any] | Any) -> Any:
if not isgenerator(gen):
return gen
def wrapper():
for item in gen:
logging.getLogger(__name__).info(
"received generator item, type=%s, value=%s",
type(item),
item,
)
yield item
return wrapper()

View File

@ -4,40 +4,42 @@ import uuid
import pytest
from sqlalchemy.orm import Session
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from models import db
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
@pytest.mark.usefixtures("flask_req_ctx")
class TestWorkflowDraftVariableService(unittest.TestCase):
_test_app_id: str
_session: Session
_node1_id = "test_node_1"
_node2_id = "test_node_2"
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._session: Session = db.session
sys_var = WorkflowDraftVariable.create_sys_variable(
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
value=build_segment("sys_value"),
)
conv_var = WorkflowDraftVariable.create_conversation_variable(
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
name="conv_var",
value=build_segment("conv_value"),
)
node2_vars = [
WorkflowDraftVariable.create_node_variable(
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="int_var",
value=build_segment(1),
visible=False,
),
WorkflowDraftVariable.create_node_variable(
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id=self._node2_id,
name="str_var",
@ -45,9 +47,9 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
visible=True,
),
]
node1_var = WorkflowDraftVariable.create_node_variable(
node1_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id="node_1",
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
visible=True,
@ -92,7 +94,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_get_node_variable(self):
srv = self._get_test_srv()
node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var")
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
assert node_var.id == self._node1_str_var_id
assert node_var.name == "str_var"
assert node_var.get_value() == build_segment("str_value")
@ -138,5 +140,86 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test__list_node_variables(self):
srv = self._get_test_srv()
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
assert len(node_vars) == 2
assert {v.id for v in node_vars} == set(self._node2_var_ids)
assert len(node_vars.variables) == 2
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
def test_get_draft_variables_by_selectors(self):
    """Selectors spanning two nodes should return exactly the three matching variables."""
    srv = self._get_test_srv()
    selectors = [
        [self._node1_id, "str_var"],
        [self._node2_id, "str_var"],
        [self._node2_id, "int_var"],
    ]
    variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
    assert len(variables) == 3
    # Compare as sets: the service does not guarantee any ordering.
    assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
@pytest.mark.usefixtures("flask_req_ctx")
class TestDraftVariableLoader(unittest.TestCase):
    """Tests for `DraftVarLoader` against draft variables persisted in the database."""

    _test_app_id: str

    _node1_id = "test_loader_node_1"

    def setUp(self):
        """Persist one system, one conversation and one node variable for a fresh app ID."""
        self._test_app_id = str(uuid.uuid4())
        sys_var = WorkflowDraftVariable.new_sys_variable(
            app_id=self._test_app_id,
            name="sys_var",
            value=build_segment("sys_value"),
        )
        conv_var = WorkflowDraftVariable.new_conversation_variable(
            app_id=self._test_app_id,
            name="conv_var",
            value=build_segment("conv_value"),
        )
        node_var = WorkflowDraftVariable.new_node_variable(
            app_id=self._test_app_id,
            node_id=self._node1_id,
            name="str_var",
            value=build_segment("str_value"),
            visible=True,
        )
        _variables = [
            node_var,
            sys_var,
            conv_var,
        ]

        with Session(bind=db.engine, expire_on_commit=False) as session:
            session.add_all(_variables)
            # `commit()` flushes pending changes itself, so no explicit flush is needed.
            session.commit()
        # `expire_on_commit=False` keeps the attributes readable after the commit.
        self._variable_ids = [v.id for v in _variables]
        self._node_var_id = node_var.id
        self._sys_var_id = sys_var.id
        self._conv_var_id = conv_var.id

    def tearDown(self):
        """Delete every draft variable created for this test's app ID."""
        with Session(bind=db.engine, expire_on_commit=False) as session:
            session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
                synchronize_session=False
            )
            session.commit()

    def test_variable_loader_with_empty_selector(self):
        """An empty selector list yields no variables."""
        var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
        variables = var_loader.load_variables([])
        assert len(variables) == 0

    def test_variable_loader_with_non_empty_selector(self):
        """System, conversation and node selectors each resolve to the stored variable."""
        var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
        variables = var_loader.load_variables(
            [
                [SYSTEM_VARIABLE_NODE_ID, "sys_var"],
                [CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
                [self._node1_id, "str_var"],
            ]
        )
        assert len(variables) == 3
        conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID)
        assert conv_var.id == self._conv_var_id
        sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
        assert sys_var.id == self._sys_var_id
        node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
        assert node1_var.id == self._node_var_id

View File

@ -1,3 +1,4 @@
import math
from dataclasses import dataclass
from uuid import uuid4
@ -232,6 +233,10 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File]:
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
# nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here.
if isinstance(value, float) and math.isnan(value):
assert math.isnan(seg.value)
else:
assert seg.value == value

View File

@ -1,22 +1,10 @@
from core.variables import SecretVariable
import dataclasses
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.utils import variable_template_parser
def test_extract_selectors_from_template():
variable_pool = VariablePool(
system_variables={
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
],
conversation_variables=[],
)
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
@ -26,3 +14,35 @@ def test_extract_selectors_from_template():
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
]
def test_invalid_references():
    """Malformed variable references must not be extracted as selectors."""

    @dataclasses.dataclass
    class TestCase:
        name: str
        template: str

    cases = [
        TestCase(
            name="lack of closing brace",
            template="Hello, {{#sys.user_id#",
        ),
        TestCase(
            name="lack of opening brace",
            template="Hello, #sys.user_id#}}",
        ),
        TestCase(
            name="lack selector name",
            template="Hello, {{#sys#}}",
        ),
        TestCase(
            name="empty node name part",
            template="Hello, {{#.user_id#}}",
        ),
    ]
    for idx, c in enumerate(cases, 1):
        fail_msg = f"Test case {c.name} failed, index={idx}"
        # Both the module-level helper and the parser class must reject the template.
        selectors = variable_template_parser.extract_selectors_from_template(c.template)
        assert selectors == [], fail_msg
        parser = variable_template_parser.VariableTemplateParser(c.template)
        assert parser.extract_variable_selectors() == [], fail_msg

View File

@ -5,7 +5,7 @@ from uuid import uuid4
import contexts
from constants import HIDDEN_VALUE
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow, WorkflowNodeExecution
from models.workflow import Workflow, WorkflowNodeExecution, is_system_variable_editable
def test_environment_variables():
@ -149,3 +149,21 @@ class TestWorkflowNodeExecution:
original = {"a": 1, "b": ["2"]}
node_exec.execution_metadata = json.dumps(original)
assert node_exec.execution_metadata_dict == original
class TestIsSystemVariableEditable:
    """Tests for `is_system_variable_editable`."""

    def test_is_system_variable(self):
        """Check the expected editability of each listed system variable name."""
        cases = [
            ("query", True),
            ("files", True),
            ("dialogue_count", False),
            ("conversation_id", False),
            ("user_id", False),
            ("app_id", False),
            ("workflow_id", False),
            ("workflow_run_id", False),
        ]
        for name, editable in cases:
            assert editable == is_system_variable_editable(name), f"unexpected editability for {name}"

        # Names that are not in the list above are expected to be non-editable.
        assert is_system_variable_editable("invalid_or_new_system_variable") is False

View File

@ -0,0 +1,82 @@
import dataclasses
import secrets
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from factories.variable_factory import build_segment
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import _DraftVariableBuilder
class TestDraftVariableBuilder:
    """Unit tests for `_DraftVariableBuilder`."""

    def _get_test_app_id(self):
        """Return a unique app ID so test cases do not share state."""
        suffix = secrets.token_hex(6)
        return f"test_app_id_{suffix}"

    def test_get_variables(self):
        """`get_variables` returns the variables held by the builder."""
        test_app_id = self._get_test_app_id()
        builder = _DraftVariableBuilder(app_id=test_app_id)
        variables = [
            WorkflowDraftVariable.new_node_variable(
                app_id=test_app_id,
                node_id="test_node_1",
                name="test_var_1",
                value=build_segment("test_value_1"),
                visible=True,
            ),
            WorkflowDraftVariable.new_sys_variable(
                app_id=test_app_id,
                name="test_sys_var",
                value=build_segment("test_sys_value"),
            ),
            WorkflowDraftVariable.new_conversation_variable(
                app_id=test_app_id,
                name="test_conv_var",
                value=build_segment("test_conv_value"),
            ),
        ]
        builder._draft_vars = variables
        assert builder.get_variables() == variables

    def test__should_variable_be_visible(self):
        """Check visibility for an if-else node output and a start node output."""
        assert _DraftVariableBuilder._should_variable_be_visible(NodeType.IF_ELSE, "123_456", "output") is False
        assert _DraftVariableBuilder._should_variable_be_visible(NodeType.START, "123", "output") is True

    def test__normalize_variable_for_start_node(self):
        """Names prefixed with `sys.` are re-homed to the system node; others are untouched."""

        @dataclasses.dataclass(frozen=True)
        class TestCase:
            name: str
            input_node_id: str
            input_name: str
            expected_node_id: str
            expected_name: str

        _NODE_ID = "1747228642872"
        cases = [
            TestCase(
                name="name with `sys.` prefix should return the system node_id",
                input_node_id=_NODE_ID,
                input_name="sys.workflow_id",
                expected_node_id=SYSTEM_VARIABLE_NODE_ID,
                expected_name="workflow_id",
            ),
            TestCase(
                name="name without `sys.` prefix should return the original input node_id",
                input_node_id=_NODE_ID,
                input_name="start_input",
                expected_node_id=_NODE_ID,
                expected_name="start_input",
            ),
            TestCase(
                name="dummy_variable should return the original input node_id",
                input_node_id=_NODE_ID,
                input_name="__dummy__",
                expected_node_id=_NODE_ID,
                expected_name="__dummy__",
            ),
        ]
        for idx, c in enumerate(cases, 1):
            fail_msg = f"Test case {c.name} failed, index={idx}"
            node_id, name = _DraftVariableBuilder._normalize_variable_for_start_node(c.input_node_id, c.input_name)
            assert node_id == c.expected_node_id, fail_msg
            assert name == c.expected_name, fail_msg