From 5cfd60ca0f3e956e9f7809ebed9f796c7b60d42d Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:04:30 +0800 Subject: [PATCH 1/9] docs(api): Add a documentation about equality of `FloatSegment` --- api/core/variables/segments.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 64ba16c367..5a470c0d2d 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -1,12 +1,11 @@ import json -import sys from collections.abc import Mapping, Sequence from typing import Any +import sys from pydantic import BaseModel, ConfigDict, field_validator from core.file import File - from .types import SegmentType @@ -75,6 +74,20 @@ class StringSegment(Segment): class FloatSegment(Segment): value_type: SegmentType = SegmentType.NUMBER value: float + # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. + # The following tests cannot pass. 
+ # + # def test_float_segment_and_nan(): + # nan = float("nan") + # assert nan != nan + # + # f1 = FloatSegment(value=float("nan")) + # f2 = FloatSegment(value=float("nan")) + # assert f1 != f2 + # + # f3 = FloatSegment(value=nan) + # f4 = FloatSegment(value=nan) + # assert f3 != f4 class IntegerSegment(Segment): From 036f0eba2c0a4123422fdbaa679bbd9dcdebd5e3 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:05:45 +0800 Subject: [PATCH 2/9] fix(api): Fix incorrect handling of `Variable` types in `VariablePool` --- api/core/variables/segments.py | 3 ++- api/core/workflow/entities/variable_pool.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 5a470c0d2d..6cf09e0372 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -1,11 +1,12 @@ import json +import sys from collections.abc import Mapping, Sequence from typing import Any -import sys from pydantic import BaseModel, ConfigDict, field_validator from core.file import File + from .types import SegmentType diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 74540491e5..d3bbc742a1 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -96,7 +96,7 @@ class VariablePool(BaseModel): if isinstance(value, Variable): variable = value - if isinstance(value, Segment): + elif isinstance(value, Segment): variable = variable_factory.segment_to_variable(segment=value, selector=selector) else: segment = variable_factory.build_segment(value) From 865dd4869c51a62fd3f4d1bcb3ad6af40845d97c Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:08:39 +0800 Subject: [PATCH 3/9] docs(api): Add documentation for `extract_variable_selector_to_variable_mapping` Add a detail description about the arguments and the return value of 
`BaseNode.extract_variable_selector_to_variable_mapping`. --- api/core/workflow/nodes/base/node.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index b2b4fe0cf1..81aa669381 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -90,8 +90,28 @@ class BaseNode(Generic[GenericNodeData]): graph_config: Mapping[str, Any], config: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping + """Extracts referenced variable selectors from node configuration. + + The `config` parameter represents the configuration for a specific node type and corresponds + to the `data` field in the node definition object. + + The returned mapping has the following structure: + + {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']} + + Here, the key consists of two parts: the current node ID (provided as the `node_id` + parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector, + enclosed in `#` symbols. These two parts are separated by a dot (`.`). + + The value is a list of strings representing the variable selector, where the first element is the node ID + of the referenced variable, and the second element is the variable name within that node. + + The meaning of the above response is: + + The node with ID `1747829548239` references the variable `result` from the node with + ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a + reference to the `result` output variable of node `1747829667553`. 
+ :param graph_config: graph config :param config: node config :return: From 0984f580a3fd5d5dfb086d322bae9f78c17d87e9 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:10:58 +0800 Subject: [PATCH 4/9] feat(api): Implement variable mapping extraction for `IfElseNode` This allows us to extract referenced variables from `IfElseNode` configurations. --- .../workflow/nodes/if_else/if_else_node.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 57792ca09a..0aa78a68e4 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,4 +1,5 @@ -from typing import Literal +from collections.abc import Mapping, Sequence +from typing import Any, Literal from typing_extensions import deprecated @@ -91,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]): return data + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IfElseNodeData, + ) -> Mapping[str, Sequence[str]]: + var_mapping: dict[str, list[str]] = {} + for case in node_data.cases or []: + for condition in case.conditions: + key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) + var_mapping[key] = condition.variable_selector + + return var_mapping + @deprecated("This function is deprecated. You should use the new cases structure.") def _should_not_use_old_function( From 141f9b4d51496505c311d7e7843f95912c5aa6f7 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:17:15 +0800 Subject: [PATCH 5/9] feat(api): Introduce a VariableLoader interface. The `VariableLoader` interface is used to load referenced variables when running a single node. 
--- api/core/workflow/variable_loader.py | 38 ++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 api/core/workflow/variable_loader.py diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py new file mode 100644 index 0000000000..59fb3a90a1 --- /dev/null +++ b/api/core/workflow/variable_loader.py @@ -0,0 +1,38 @@ +import abc +from typing import Protocol + +from core.variables import Variable + + +class VariableLoader(Protocol): + """Interface for loading variables based on selectors. + + A `VariableLoader` is responsible for retrieving additional variables required during the execution + of a single node, which are not provided as user inputs. + + NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same + application and share the same `app_id`. However, this interface does not enforce that constraint, + and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of + concerns and allow for flexible implementations. + + Implementations of `VariableLoader` should almost always have an `app_id` parameter in + their constructor. + + TODO(QuantumGhost): this is a temporary workaround. If we can move the creation of node instance into + `WorkflowService.single_step_run`, we may get rid of this interface. + """ + + @abc.abstractmethod + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + """Load variables based on the provided selectors. If the selectors are empty, + this method should return an empty list. + + The order of the returned variables is not guaranteed. If the caller wants to ensure + a specific order, they should sort the returned list themselves. + + :param selectors: a list of string lists, each inner list should have at least two elements: + - the first element is the node ID, + - the second element is the variable name. + :return: a list of Variable objects that match the provided selectors. 
+ """ + pass From defe8fea636173d280a164f623e25edf7e7712ed Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:20:56 +0800 Subject: [PATCH 6/9] feat(api): Implement VariableLoader in `DraftVarLoader`. Use DraftVarLoader to load required variables when single stepping a node. --- api/core/workflow/workflow_entry.py | 45 +++++++++----- .../workflow_draft_variable_service.py | 60 ++++++++++++++++++- api/services/workflow_service.py | 54 +++++++---------- 3 files changed, 110 insertions(+), 49 deletions(-) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index b2cfa23aa2..4225ad7c44 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -2,12 +2,13 @@ import logging import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence -from typing import Any, Optional, TypeAlias, TypeVar, cast +from typing import Any, Optional, TypeAlias, cast from configs import dify_config from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File +from core.variables import Variable from core.workflow.callbacks import WorkflowCallback from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey @@ -22,13 +23,27 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.variable_loader import VariableLoader from factories import file_factory +from libs import gen_utils from models.enums import UserFrom from models.workflow import ( Workflow, WorkflowType, ) + +class _DummyVariableLoader(VariableLoader): + """A dummy implementation of VariableLoader that does not load any variables. 
+ Serves as a placeholder when no variable loading is needed. + """ + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + return [] + + +_DUMMY_VARIABLE_LOADER = _DummyVariableLoader() + logger = logging.getLogger(__name__) @@ -122,6 +137,7 @@ class WorkflowEntry: user_id: str, user_inputs: dict, conversation_variables: dict | None = None, + variable_loader: VariableLoader = _DUMMY_VARIABLE_LOADER, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -190,6 +206,19 @@ class WorkflowEntry: except NotImplementedError: variable_mapping = {} + # Loading missing variable from draft var here, and set it into + # variable_pool. + variables_to_load: list[list[str]] = [] + for key, selector in variable_mapping.items(): + trimmed_key = key.removeprefix(f"{node_id}.") + if trimmed_key in user_inputs: + continue + if variable_pool.get(selector) is None: + variables_to_load.append(list(selector)) + loaded = variable_loader.load_variables(variables_to_load) + for var in loaded: + variable_pool.add(var.selector, var.value) + cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, @@ -204,7 +233,7 @@ class WorkflowEntry: except Exception as e: raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) if metadata_attacher: - generator = _wrap_generator(generator, metadata_attacher) + generator = gen_utils.map_(generator, metadata_attacher) return node_instance, generator @classmethod @@ -391,18 +420,6 @@ class WorkflowEntry: variable_pool.add([variable_node_id] + variable_key_list, input_value) -_YieldT_co = TypeVar("_YieldT_co", covariant=True) -_YieldR_co = TypeVar("_YieldR_co", covariant=True) - - -def _wrap_generator( - gen: Generator[_YieldT_co, None, None], - mapper: Callable[[_YieldT_co], _YieldR_co], -) -> Generator[_YieldR_co, None, None]: - for item in gen: - yield mapper(item) - - _NodeOrInNodeEvent: TypeAlias = NodeEvent | 
InNodeEvent diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 15ff43b6e5..a37efd9000 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -3,17 +3,19 @@ import logging from collections.abc import Mapping, Sequence from typing import Any -from sqlalchemy import orm +from sqlalchemy import Engine, orm from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import and_, or_ from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.constants import is_dummy_output_variable -from core.variables import Segment +from core.variables import Segment, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes import NodeType +from core.workflow.variable_loader import VariableLoader from factories import variable_factory +from factories.variable_factory import build_segment, segment_to_variable from models.workflow import WorkflowDraftVariable, is_system_variable_editable _logger = logging.getLogger(__name__) @@ -25,6 +27,36 @@ class WorkflowDraftVariableList: total: int | None = None +class DraftVarLoader(VariableLoader): + # This implements the VariableLoader interface for loading draft variables. 
+ # + # ref: core.workflow.variable_loader.VariableLoader + def __init__(self, engine: Engine, app_id: str) -> None: + self._engine = engine + self._app_id = app_id + + def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + if not selectors: + return [] + with Session(bind=self._engine, expire_on_commit=False) as session: + srv = WorkflowDraftVariableService(session) + draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors) + variables = [] + for draft_var in draft_vars: + segment = build_segment( + draft_var.value, + ) + variable = segment_to_variable( + segment=segment, + selector=draft_var.get_selector(), + id=draft_var.id, + name=draft_var.name, + description=draft_var.description, + ) + variables.append(variable) + return variables + + class WorkflowDraftVariableService: _session: Session @@ -34,6 +66,28 @@ class WorkflowDraftVariableService: def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() + def get_draft_variables_by_selectors( + self, + app_id: str, + selectors: Sequence[list[str]], + ) -> list[WorkflowDraftVariable]: + ors = [] + for selector in selectors: + node_id, name = selector + ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) + + # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as + # each expression includes conditions on both `node_id` and `name` (which are covered by the unique index), + # PostgreSQL can efficiently retrieve the results using a bitmap index scan. + # + # Alternatively, a `SELECT` statement could be constructed for each selector and + # combined using `UNION` to fetch all rows. + # Benchmarking indicates that both approaches yield comparable performance. 
+ variables = ( + self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all() + ) + return variables + def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]): variable_builder = _DraftVariableBuilder(app_id=app_id) variable_builder.build(node_id=node_id, node_type=node_type, output=output) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b44f2706a8..23e19235b6 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,7 +3,6 @@ import logging import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime -from inspect import isgenerator from typing import Any, Optional from uuid import uuid4 @@ -28,6 +27,7 @@ from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_M from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db +from libs import gen_utils from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -42,7 +42,11 @@ from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from .workflow_draft_variable_service import WorkflowDraftVariableService, should_save_output_variables_for_draft +from .workflow_draft_variable_service import ( + DraftVarLoader, + WorkflowDraftVariableService, + should_save_output_variables_for_draft, +) class WorkflowService: @@ -310,26 +314,28 @@ class WorkflowService: if not draft_workflow: raise ValueError("Workflow not initialized") - # conv_vars = common_helpers.get_conversation_variables() - - # run draft workflow node - start_at = 
time.perf_counter() + # TODO(QuantumGhost): We may get rid of the `list_conversation_variables` + # here, and rely on `DraftVarLoader` to load conversation variables. with Session(bind=db.engine) as session: - # TODO(QunatumGhost): inject conversation variables - # to variable pool. draft_var_srv = WorkflowDraftVariableService(session) conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id) conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables} + variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id) + run = WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + conversation_variables=conv_var_mapping, + variable_loader=variable_loader, + ) + + # run draft workflow node + start_at = time.perf_counter() node_execution = self._handle_node_run_result( - invoke_node_fn=lambda: WorkflowEntry.single_step_run( - workflow=draft_workflow, - node_id=node_id, - user_inputs=user_inputs, - user_id=account.id, - conversation_variables=conv_var_mapping, - ), + invoke_node_fn=lambda: run, start_at=start_at, node_id=node_id, ) @@ -403,7 +409,7 @@ class WorkflowService: ) -> NodeExecution: try: node_instance, generator = invoke_node_fn() - generator = _inspect_generator(generator) + generator = gen_utils.inspect(generator, logging.getLogger(__name__)) node_run_result: NodeRunResult | None = None for event in generator: @@ -610,19 +616,3 @@ class WorkflowService: session.delete(workflow) return True - - -def _inspect_generator(gen: Generator[Any] | Any) -> Any: - if not isgenerator(gen): - return gen - - def wrapper(): - for item in gen: - logging.getLogger(__name__).info( - "received generator item, type=%s, value=%s", - type(item), - item, - ) - yield item - - return wrapper() From 3b45b87d02304b5d29f514fa976b45a34794c698 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:22:15 +0800 Subject: [PATCH 7/9] fix(api): Fix input 
variable handling for `Start` node. --- .../workflow_draft_variable_service.py | 70 +++++-------------- 1 file changed, 18 insertions(+), 52 deletions(-) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index a37efd9000..504d693742 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -96,29 +96,25 @@ class WorkflowDraftVariableService: if not draft_variables: return - # We may use SQLAlchemy ORM operation here. However, considering the fact that: + # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons: # - # 1. The variable saving process writes multiple rows into one table (`workflow_draft_variables`). - # Use batch insertion may increase performance dramatically. - # 2. If we use ORM operation, we need to either: + # 1. The variable saving process involves writing multiple rows to the + # `workflow_draft_variables` table. Batch insertion significantly improves performance. + # 2. Using the ORM would require either: # - # a. Check the existence for each variable before insertion. - # b. Try insertion first, then do update if insertion fails due to unique index violation. + # a. Checking for the existence of each variable before insertion, + # resulting in 2n SQL statements for n variables and potential concurrency issues. + # b. Attempting insertion first, then updating if a unique index violation occurs, + # which still results in n to 2n SQL statements. # - # Neither of the above is satisfactory. + # Both approaches are inefficient and suboptimal. + # 3. We do not need to retrieve the results of the SQL execution or populate ORM + # model instances with the returned values. + # 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all + # variables in a single SQL statement, avoiding the issues above. 
# - # - For implementation "a", we need to issue `2n` sqls for `n` variables in output. - # Besides, it's still suffer from concurrency issues. - # - For implementation "b", we need to issue `n` - `2n` sqls (depending on the existence of - # specific variable), which is lesser than plan "a" but still far from ideal. - # - # 3. We do not need the value of SQL execution, nor do we need populate those values back to ORM model - # instances. - # 4. Batch insertion can be combined with `ON CONFLICT DO UPDATE`, allows us to insert or update - # all variables in one SQL statement, and avoid all problems above. - # - # Given reasons above, we use query builder instead of using ORM layer, - # and rely on dialect specific insert operations. + # For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific + # insert operations instead of the ORM layer. if node_type == NodeType.CODE: # Clear existing variable for code node. self._session.query(WorkflowDraftVariable).filter( @@ -267,26 +263,6 @@ def should_save_output_variables_for_draft( return True -# def should_save_output_variables_for_draft(invoke_from: InvokeFrom, node_exec: WorkflowNodeExecution) -> bool: -# # Only save output variables for debugging execution of workflow. -# if invoke_from != InvokeFrom.DEBUGGER: -# return False -# exec_metadata = node_exec.execution_metadata_dict -# if exec_metadata is None: -# # No execution metadata, assume the node is not in loop or iteration. -# return True -# -# # Currently we do not save output variables for nodes inside loop or iteration. 
-# loop_id = exec_metadata.get(NodeRunMetadataKey.LOOP_ID) -# if loop_id is not None: -# return False -# iteration_id = exec_metadata.get(NodeRunMetadataKey.ITERATION_ID) -# if iteration_id is not None: -# return False -# return True -# - - class _DraftVariableBuilder: _app_id: str _draft_vars: list[WorkflowDraftVariable] @@ -334,7 +310,8 @@ class _DraftVariableBuilder: original_node_id = node_id for name, value in output.items(): value_seg = variable_factory.build_segment(value) - if is_dummy_output_variable(name): + node_id, name = self._normalize_variable_for_start_node(node_id, name) + if node_id != SYSTEM_VARIABLE_NODE_ID: self._draft_vars.append( WorkflowDraftVariable.new_node_variable( app_id=self._app_id, @@ -356,20 +333,9 @@ class _DraftVariableBuilder: ) @staticmethod - def _normalize_variable_for_start_node(node_type: NodeType, node_id: str, name: str): - if node_type != NodeType.START: - return node_id, name - - # TODO(QuantumGhost): need special handling for dummy output variable in - # `Start` node. 
+ def _normalize_variable_for_start_node(node_id: str, name: str) -> tuple[str, str]: if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): return node_id, name - _logger.debug( - "Normalizing variable: node_type=%s, node_id=%s, name=%s", - node_type, - node_id, - name, - ) node_id, name_ = name.split(".", maxsplit=1) return node_id, name_ From 4cf9d2306903a43e0b325d1188d8f14a6f2067dd Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:22:57 +0800 Subject: [PATCH 8/9] feat(api): Add some utility functions for working with generators --- api/libs/gen_utils.py | 53 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 api/libs/gen_utils.py diff --git a/api/libs/gen_utils.py b/api/libs/gen_utils.py new file mode 100644 index 0000000000..e457176187 --- /dev/null +++ b/api/libs/gen_utils.py @@ -0,0 +1,53 @@ +"""Utility functions for working with generators.""" + +import logging +from collections.abc import Callable, Generator +from inspect import isgenerator +from typing import TypeVar + +_YieldT = TypeVar("_YieldT") +_YieldR = TypeVar("_YieldR") + +_T = TypeVar("_T") + + +def inspect(gen_or_normal: _T, logger: logging.Logger) -> _T: + if not isgenerator(gen_or_normal): + return gen_or_normal + + def wrapper(): + for item in gen_or_normal: + logger.info( + "received generator item, type=%s, value=%s", + type(item), + item, + ) + yield item + + return wrapper() + + +def map_( + gen: Generator[_YieldT, None, None], + mapper: Callable[[_YieldT], _YieldR], +) -> Generator[_YieldR, None, None]: + for item in gen: + yield mapper(item) + + +def filter_( + gen: Generator[_YieldT, None, None], + mapper: Callable[[_YieldT], bool], +) -> Generator[_YieldT, None, None]: + for item in gen: + if mapper(item): + yield item + + +def wrap( + gen: Generator[_YieldT, None, None], + funcs: list[Callable[[Generator[_YieldT, None, None]], Generator[_YieldT, None, None]]], +) -> Generator[_YieldT, None, None]: + for f in funcs: + 
gen = f(gen) + return gen From f1fd05ccda7eca9637d2c904956a548af422a717 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 26 May 2025 14:23:34 +0800 Subject: [PATCH 9/9] test(api): Add some tests for DraftVariable related code --- .../test_workflow_draft_variable_service.py | 103 ++++++++++++++++-- .../core/app/segments/test_factory.py | 7 +- .../utils/test_variable_template_parser.py | 48 +++++--- api/tests/unit_tests/models/test_workflow.py | 20 +++- .../test_workflow_draft_variable_service.py | 82 ++++++++++++++ 5 files changed, 234 insertions(+), 26 deletions(-) create mode 100644 api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index fbe7826b3a..8288dc7bd8 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -4,40 +4,42 @@ import uuid import pytest from sqlalchemy.orm import Session +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment from models import db from models.workflow import WorkflowDraftVariable -from services.workflow_draft_variable_service import WorkflowDraftVariableService +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService @pytest.mark.usefixtures("flask_req_ctx") class TestWorkflowDraftVariableService(unittest.TestCase): _test_app_id: str _session: Session + _node1_id = "test_node_1" _node2_id = "test_node_2" def setUp(self): self._test_app_id = str(uuid.uuid4()) self._session: Session = db.session - sys_var = WorkflowDraftVariable.create_sys_variable( + sys_var = WorkflowDraftVariable.new_sys_variable( app_id=self._test_app_id, name="sys_var", value=build_segment("sys_value"), ) - 
conv_var = WorkflowDraftVariable.create_conversation_variable( + conv_var = WorkflowDraftVariable.new_conversation_variable( app_id=self._test_app_id, name="conv_var", value=build_segment("conv_value"), ) node2_vars = [ - WorkflowDraftVariable.create_node_variable( + WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, node_id=self._node2_id, name="int_var", value=build_segment(1), visible=False, ), - WorkflowDraftVariable.create_node_variable( + WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, node_id=self._node2_id, name="str_var", @@ -45,9 +47,9 @@ class TestWorkflowDraftVariableService(unittest.TestCase): visible=True, ), ] - node1_var = WorkflowDraftVariable.create_node_variable( + node1_var = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, - node_id="node_1", + node_id=self._node1_id, name="str_var", value=build_segment("str_value"), visible=True, @@ -92,7 +94,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_get_node_variable(self): srv = self._get_test_srv() - node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var") + node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var") assert node_var.id == self._node1_str_var_id assert node_var.name == "str_var" assert node_var.get_value() == build_segment("str_value") @@ -138,5 +140,86 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test__list_node_variables(self): srv = self._get_test_srv() node_vars = srv._list_node_variables(self._test_app_id, self._node2_id) - assert len(node_vars) == 2 - assert {v.id for v in node_vars} == set(self._node2_var_ids) + assert len(node_vars.variables) == 2 + assert {v.id for v in node_vars.variables} == set(self._node2_var_ids) + + def test_get_draft_variables_by_selectors(self): + srv = self._get_test_srv() + selectors = [ + [self._node1_id, "str_var"], + [self._node2_id, "str_var"], + [self._node2_id, "int_var"], + ] + variables = 
srv.get_draft_variables_by_selectors(self._test_app_id, selectors) + assert len(variables) == 3 + assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids) + + +@pytest.mark.usefixtures("flask_req_ctx") +class TestDraftVariableLoader(unittest.TestCase): + _test_app_id: str + + _node1_id = "test_loader_node_1" + + def setUp(self): + self._test_app_id = str(uuid.uuid4()) + sys_var = WorkflowDraftVariable.new_sys_variable( + app_id=self._test_app_id, + name="sys_var", + value=build_segment("sys_value"), + ) + conv_var = WorkflowDraftVariable.new_conversation_variable( + app_id=self._test_app_id, + name="conv_var", + value=build_segment("conv_value"), + ) + node_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id=self._node1_id, + name="str_var", + value=build_segment("str_value"), + visible=True, + ) + _variables = [ + node_var, + sys_var, + conv_var, + ] + + with Session(bind=db.engine, expire_on_commit=False) as session: + session.add_all(_variables) + session.flush() + session.commit() + self._variable_ids = [v.id for v in _variables] + self._node_var_id = node_var.id + self._sys_var_id = sys_var.id + self._conv_var_id = conv_var.id + + def tearDown(self): + with Session(bind=db.engine, expire_on_commit=False) as session: + session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete( + synchronize_session=False + ) + session.commit() + + def test_variable_loader_with_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id) + variables = var_loader.load_variables([]) + assert len(variables) == 0 + + def test_variable_loader_with_non_empty_selector(self): + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id) + variables = var_loader.load_variables( + [ + [SYSTEM_VARIABLE_NODE_ID, "sys_var"], + [CONVERSATION_VARIABLE_NODE_ID, "conv_var"], + [self._node1_id, "str_var"], + ] + ) + assert len(variables) == 3 + 
conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID) + assert conv_var.id == self._conv_var_id + sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID) + assert sys_var.id == self._sys_var_id + node1_var = next(v for v in variables if v.selector[0] == self._node1_id) + assert node1_var.id == self._node_var_id diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index 68fc85aa17..725351d429 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -1,3 +1,4 @@ +import math from dataclasses import dataclass from uuid import uuid4 @@ -232,7 +233,11 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File]: @given(_scalar_value()) def test_build_segment_and_extract_values_for_scalar_types(value): seg = variable_factory.build_segment(value) - assert seg.value == value + # nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here. 
+ if isinstance(value, float) and math.isnan(value): + assert math.isnan(seg.value) + else: + assert seg.value == value @given(st.lists(_scalar_value())) diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py index 2f90afcf89..28ef05edde 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py @@ -1,22 +1,10 @@ -from core.variables import SecretVariable +import dataclasses + from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.utils import variable_template_parser def test_extract_selectors_from_template(): - variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, - user_inputs={}, - environment_variables=[ - SecretVariable(name="secret_key", value="fake-secret-key"), - ], - conversation_variables=[], - ) - variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}." 
) @@ -26,3 +14,35 @@ def test_extract_selectors_from_template(): VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]), VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]), ] + + +def test_invalid_references(): + @dataclasses.dataclass + class TestCase: + name: str + template: str + + cases = [ + TestCase( + name="lack of closing brace", + template="Hello, {{#sys.user_id#", + ), + TestCase( + name="lack of opening brace", + template="Hello, #sys.user_id#}}", + ), + TestCase( + name="lack selector name", + template="Hello, {{#sys#}}", + ), + TestCase( + name="empty node name part", + template="Hello, {{#.user_id#}}", + ), + ] + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + selectors = variable_template_parser.extract_selectors_from_template(c.template) + assert selectors == [], fail_msg + parser = variable_template_parser.VariableTemplateParser(c.template) + assert parser.extract_variable_selectors() == [], fail_msg diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 70ce224eb6..e7633e6141 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -5,7 +5,7 @@ from uuid import uuid4 import contexts from constants import HIDDEN_VALUE from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from models.workflow import Workflow, WorkflowNodeExecution +from models.workflow import Workflow, WorkflowNodeExecution, is_system_variable_editable def test_environment_variables(): @@ -149,3 +149,21 @@ class TestWorkflowNodeExecution: original = {"a": 1, "b": ["2"]} node_exec.execution_metadata = json.dumps(original) assert node_exec.execution_metadata_dict == original + + +class TestIsSystemVariableEditable: + def test_is_system_variable(self): + cases = [ + ("query", True), + ("files", True), + ("dialogue_count", 
False), + ("conversation_id", False), + ("user_id", False), + ("app_id", False), + ("workflow_id", False), + ("workflow_run_id", False), + ] + for name, editable in cases: + assert editable == is_system_variable_editable(name) + + assert is_system_variable_editable("invalid_or_new_system_variable") == False diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py new file mode 100644 index 0000000000..c1da6eaede --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -0,0 +1,82 @@ +import dataclasses +import secrets + +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.nodes import NodeType +from factories.variable_factory import build_segment +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import _DraftVariableBuilder + + +class TestDraftVariableBuilder: + def _get_test_app_id(self): + suffix = secrets.token_hex(6) + return f"test_app_id_{suffix}" + + def test_get_variables(self): + test_app_id = self._get_test_app_id() + builder = _DraftVariableBuilder(app_id=test_app_id) + variables = [ + WorkflowDraftVariable.new_node_variable( + app_id=test_app_id, + node_id="test_node_1", + name="test_var_1", + value=build_segment("test_value_1"), + visible=True, + ), + WorkflowDraftVariable.new_sys_variable( + app_id=test_app_id, + name="test_sys_var", + value=build_segment("test_sys_value"), + ), + WorkflowDraftVariable.new_conversation_variable( + app_id=test_app_id, + name="test_conv_var", + value=build_segment("test_conv_value"), + ), + ] + builder._draft_vars = variables + assert builder.get_variables() == variables + + def test__should_variable_be_visible(self): + assert _DraftVariableBuilder._should_variable_be_visible(NodeType.IF_ELSE, "123_456", "output") == False + assert 
_DraftVariableBuilder._should_variable_be_visible(NodeType.START, "123", "output") == True + + def test__normalize_variable_for_start_node(self): + @dataclasses.dataclass(frozen=True) + class TestCase: + name: str + input_node_id: str + input_name: str + expected_node_id: str + expected_name: str + + _NODE_ID = "1747228642872" + cases = [ + TestCase( + name="name with `sys.` prefix should return the system node_id", + input_node_id=_NODE_ID, + input_name="sys.workflow_id", + expected_node_id=SYSTEM_VARIABLE_NODE_ID, + expected_name="workflow_id", + ), + TestCase( + name="name without `sys.` prefix should return the original input node_id", + input_node_id=_NODE_ID, + input_name="start_input", + expected_node_id=_NODE_ID, + expected_name="start_input", + ), + TestCase( + name="dummy_variable should return the original input node_id", + input_node_id=_NODE_ID, + input_name="__dummy__", + expected_node_id=_NODE_ID, + expected_name="__dummy__", + ), + ] + for idx, c in enumerate(cases, 1): + fail_msg = f"Test case {c.name} failed, index={idx}" + node_id, name = _DraftVariableBuilder._normalize_variable_for_start_node(c.input_node_id, c.input_name) + assert node_id == c.expected_node_id, fail_msg + assert name == c.expected_name, fail_msg