mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-16 04:15:57 +08:00
Merge branch 'feat/variable-pool-rebased' into deploy/dev
This commit is contained in:
commit
a7bfe67797
@ -75,6 +75,20 @@ class StringSegment(Segment):
|
|||||||
class FloatSegment(Segment):
|
class FloatSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.NUMBER
|
value_type: SegmentType = SegmentType.NUMBER
|
||||||
value: float
|
value: float
|
||||||
|
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
||||||
|
# The following tests cannot pass.
|
||||||
|
#
|
||||||
|
# def test_float_segment_and_nan():
|
||||||
|
# nan = float("nan")
|
||||||
|
# assert nan != nan
|
||||||
|
#
|
||||||
|
# f1 = FloatSegment(value=float("nan"))
|
||||||
|
# f2 = FloatSegment(value=float("nan"))
|
||||||
|
# assert f1 != f2
|
||||||
|
#
|
||||||
|
# f3 = FloatSegment(value=nan)
|
||||||
|
# f4 = FloatSegment(value=nan)
|
||||||
|
# assert f3 != f4
|
||||||
|
|
||||||
|
|
||||||
class IntegerSegment(Segment):
|
class IntegerSegment(Segment):
|
||||||
|
@ -96,7 +96,7 @@ class VariablePool(BaseModel):
|
|||||||
|
|
||||||
if isinstance(value, Variable):
|
if isinstance(value, Variable):
|
||||||
variable = value
|
variable = value
|
||||||
if isinstance(value, Segment):
|
elif isinstance(value, Segment):
|
||||||
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
|
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
|
||||||
else:
|
else:
|
||||||
segment = variable_factory.build_segment(value)
|
segment = variable_factory.build_segment(value)
|
||||||
|
@ -90,8 +90,28 @@ class BaseNode(Generic[GenericNodeData]):
|
|||||||
graph_config: Mapping[str, Any],
|
graph_config: Mapping[str, Any],
|
||||||
config: Mapping[str, Any],
|
config: Mapping[str, Any],
|
||||||
) -> Mapping[str, Sequence[str]]:
|
) -> Mapping[str, Sequence[str]]:
|
||||||
"""
|
"""Extracts references variable selectors from node configuration.
|
||||||
Extract variable selector to variable mapping
|
|
||||||
|
The `config` parameter represents the configuration for a specific node type and corresponds
|
||||||
|
to the `data` field in the node definition object.
|
||||||
|
|
||||||
|
The returned mapping has the following structure:
|
||||||
|
|
||||||
|
{'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
|
||||||
|
|
||||||
|
Here, the key consists of two parts: the current node ID (provided as the `node_id`
|
||||||
|
parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
|
||||||
|
enclosed in `#` symbols. These two parts are separated by a dot (`.`).
|
||||||
|
|
||||||
|
The value is a list of string representing the variable selector, where the first element is the node ID
|
||||||
|
of the referenced variable, and the second element is the variable name within that node.
|
||||||
|
|
||||||
|
The meaning of the above response is:
|
||||||
|
|
||||||
|
The node with ID `1747829548239` references the variable `result` from the node with
|
||||||
|
ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a
|
||||||
|
reference to the `result` output variable of node `1747829667553`.
|
||||||
|
|
||||||
:param graph_config: graph config
|
:param graph_config: graph config
|
||||||
:param config: node config
|
:param config: node config
|
||||||
:return:
|
:return:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from typing import Literal
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
@ -91,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
graph_config: Mapping[str, Any],
|
||||||
|
node_id: str,
|
||||||
|
node_data: IfElseNodeData,
|
||||||
|
) -> Mapping[str, Sequence[str]]:
|
||||||
|
var_mapping: dict[str, list[str]] = {}
|
||||||
|
for case in node_data.cases or []:
|
||||||
|
for condition in case.conditions:
|
||||||
|
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
|
||||||
|
var_mapping[key] = condition.variable_selector
|
||||||
|
|
||||||
|
return var_mapping
|
||||||
|
|
||||||
|
|
||||||
@deprecated("This function is deprecated. You should use the new cases structure.")
|
@deprecated("This function is deprecated. You should use the new cases structure.")
|
||||||
def _should_not_use_old_function(
|
def _should_not_use_old_function(
|
||||||
|
38
api/core/workflow/variable_loader.py
Normal file
38
api/core/workflow/variable_loader.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from core.variables import Variable
|
||||||
|
|
||||||
|
|
||||||
|
class VariableLoader(Protocol):
|
||||||
|
"""Interface for loading variables based on selectors.
|
||||||
|
|
||||||
|
A `VariableLoader` is responsible for retrieving additional variables required during the execution
|
||||||
|
of a single node, which are not provided as user inputs.
|
||||||
|
|
||||||
|
NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same
|
||||||
|
application and share the same `app_id`. However, this interface does not enforce that constraint,
|
||||||
|
and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of
|
||||||
|
concern and allow for flexible implementations.
|
||||||
|
|
||||||
|
Implementations of `VariableLoader` should almost always have an `app_id` parameter in
|
||||||
|
their constructor.
|
||||||
|
|
||||||
|
TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into
|
||||||
|
`WorkflowService.single_step_run`, we may get rid of this interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||||
|
"""Load variables based on the provided selectors. If the selectors are empty,
|
||||||
|
this method should return an empty list.
|
||||||
|
|
||||||
|
The order of the returned variables is not guaranteed. If the caller wants to ensure
|
||||||
|
a specific order, they should sort the returned list themselves.
|
||||||
|
|
||||||
|
:param: selectors: a list of string list, each inner list should have at least two elements:
|
||||||
|
- the first element is the node ID,
|
||||||
|
- the second element is the variable name.
|
||||||
|
:return: a list of Variable objects that match the provided selectors.
|
||||||
|
"""
|
||||||
|
pass
|
@ -2,12 +2,13 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||||
from typing import Any, Optional, TypeAlias, TypeVar, cast
|
from typing import Any, Optional, TypeAlias, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
from core.variables import Variable
|
||||||
from core.workflow.callbacks import WorkflowCallback
|
from core.workflow.callbacks import WorkflowCallback
|
||||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
@ -22,13 +23,27 @@ from core.workflow.nodes import NodeType
|
|||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
|
from core.workflow.variable_loader import VariableLoader
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
|
from libs import gen_utils
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowType,
|
WorkflowType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyVariableLoader(VariableLoader):
|
||||||
|
"""A dummy implementation of VariableLoader that does not load any variables.
|
||||||
|
Serves as a placeholder when no variable loading is needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
_DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -122,6 +137,7 @@ class WorkflowEntry:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
user_inputs: dict,
|
user_inputs: dict,
|
||||||
conversation_variables: dict | None = None,
|
conversation_variables: dict | None = None,
|
||||||
|
variable_loader: VariableLoader = _DUMMY_VARIABLE_LOADER,
|
||||||
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
|
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
|
||||||
"""
|
"""
|
||||||
Single step run workflow node
|
Single step run workflow node
|
||||||
@ -190,6 +206,19 @@ class WorkflowEntry:
|
|||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
variable_mapping = {}
|
variable_mapping = {}
|
||||||
|
|
||||||
|
# Loading missing variable from draft var here, and set it into
|
||||||
|
# variable_pool.
|
||||||
|
variables_to_load: list[list[str]] = []
|
||||||
|
for key, selector in variable_mapping.items():
|
||||||
|
trimmed_key = key.removeprefix(f"{node_id}.")
|
||||||
|
if trimmed_key in user_inputs:
|
||||||
|
continue
|
||||||
|
if variable_pool.get(selector) is None:
|
||||||
|
variables_to_load.append(list(selector))
|
||||||
|
loaded = variable_loader.load_variables(variables_to_load)
|
||||||
|
for var in loaded:
|
||||||
|
variable_pool.add(var.selector, var.value)
|
||||||
|
|
||||||
cls.mapping_user_inputs_to_variable_pool(
|
cls.mapping_user_inputs_to_variable_pool(
|
||||||
variable_mapping=variable_mapping,
|
variable_mapping=variable_mapping,
|
||||||
user_inputs=user_inputs,
|
user_inputs=user_inputs,
|
||||||
@ -204,7 +233,7 @@ class WorkflowEntry:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
||||||
if metadata_attacher:
|
if metadata_attacher:
|
||||||
generator = _wrap_generator(generator, metadata_attacher)
|
generator = gen_utils.map_(generator, metadata_attacher)
|
||||||
return node_instance, generator
|
return node_instance, generator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -391,18 +420,6 @@ class WorkflowEntry:
|
|||||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||||
|
|
||||||
|
|
||||||
_YieldT_co = TypeVar("_YieldT_co", covariant=True)
|
|
||||||
_YieldR_co = TypeVar("_YieldR_co", covariant=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _wrap_generator(
|
|
||||||
gen: Generator[_YieldT_co, None, None],
|
|
||||||
mapper: Callable[[_YieldT_co], _YieldR_co],
|
|
||||||
) -> Generator[_YieldR_co, None, None]:
|
|
||||||
for item in gen:
|
|
||||||
yield mapper(item)
|
|
||||||
|
|
||||||
|
|
||||||
_NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent
|
_NodeOrInNodeEvent: TypeAlias = NodeEvent | InNodeEvent
|
||||||
|
|
||||||
|
|
||||||
|
53
api/libs/gen_utils.py
Normal file
53
api/libs/gen_utils.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
"""Utility functions for working with generators."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable, Generator
|
||||||
|
from inspect import isgenerator
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
_YieldT = TypeVar("_YieldT")
|
||||||
|
_YieldR = TypeVar("_YieldR")
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
def inspect(gen_or_normal: _T, logger: logging.Logger) -> _T:
|
||||||
|
if not isgenerator(gen_or_normal):
|
||||||
|
return gen_or_normal
|
||||||
|
|
||||||
|
def wrapper():
|
||||||
|
for item in gen_or_normal:
|
||||||
|
logger.info(
|
||||||
|
"received generator item, type=%s, value=%s",
|
||||||
|
type(item),
|
||||||
|
item,
|
||||||
|
)
|
||||||
|
yield item
|
||||||
|
|
||||||
|
return wrapper()
|
||||||
|
|
||||||
|
|
||||||
|
def map_(
|
||||||
|
gen: Generator[_YieldT, None, None],
|
||||||
|
mapper: Callable[[_YieldT], _YieldR],
|
||||||
|
) -> Generator[_YieldR, None, None]:
|
||||||
|
for item in gen:
|
||||||
|
yield mapper(item)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_(
|
||||||
|
gen: Generator[_YieldT, None, None],
|
||||||
|
mapper: Callable[[_YieldT], bool],
|
||||||
|
) -> Generator[_YieldT, None, None]:
|
||||||
|
for item in gen:
|
||||||
|
if mapper(item):
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
def wrap(
|
||||||
|
gen: Generator[_YieldT, None, None],
|
||||||
|
funcs: list[Callable[[Generator[_YieldT, None, None]], Generator[_YieldT, None, None]]],
|
||||||
|
) -> Generator[_YieldT, None, None]:
|
||||||
|
for f in funcs:
|
||||||
|
gen = f(gen)
|
||||||
|
return gen
|
@ -3,17 +3,19 @@ import logging
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import orm
|
from sqlalchemy import Engine, orm
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy.sql.expression import and_, or_
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file.constants import is_dummy_output_variable
|
from core.variables import Segment, Variable
|
||||||
from core.variables import Segment
|
|
||||||
from core.variables.consts import MIN_SELECTORS_LENGTH
|
from core.variables.consts import MIN_SELECTORS_LENGTH
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
|
from core.workflow.variable_loader import VariableLoader
|
||||||
from factories import variable_factory
|
from factories import variable_factory
|
||||||
|
from factories.variable_factory import build_segment, segment_to_variable
|
||||||
from models.workflow import WorkflowDraftVariable, is_system_variable_editable
|
from models.workflow import WorkflowDraftVariable, is_system_variable_editable
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
@ -25,6 +27,36 @@ class WorkflowDraftVariableList:
|
|||||||
total: int | None = None
|
total: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DraftVarLoader(VariableLoader):
|
||||||
|
# This implements the VariableLoader interface for loading draft variables.
|
||||||
|
#
|
||||||
|
# ref: core.workflow.variable_loader.VariableLoader
|
||||||
|
def __init__(self, engine: Engine, app_id: str) -> None:
|
||||||
|
self._engine = engine
|
||||||
|
self._app_id = app_id
|
||||||
|
|
||||||
|
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||||
|
if not selectors:
|
||||||
|
return []
|
||||||
|
with Session(bind=self._engine, expire_on_commit=False) as session:
|
||||||
|
srv = WorkflowDraftVariableService(session)
|
||||||
|
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
|
||||||
|
variables = []
|
||||||
|
for draft_var in draft_vars:
|
||||||
|
segment = build_segment(
|
||||||
|
draft_var.value,
|
||||||
|
)
|
||||||
|
variable = segment_to_variable(
|
||||||
|
segment=segment,
|
||||||
|
selector=draft_var.get_selector(),
|
||||||
|
id=draft_var.id,
|
||||||
|
name=draft_var.name,
|
||||||
|
description=draft_var.description,
|
||||||
|
)
|
||||||
|
variables.append(variable)
|
||||||
|
return variables
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDraftVariableService:
|
class WorkflowDraftVariableService:
|
||||||
_session: Session
|
_session: Session
|
||||||
|
|
||||||
@ -34,6 +66,28 @@ class WorkflowDraftVariableService:
|
|||||||
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
|
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
|
||||||
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
|
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
|
||||||
|
|
||||||
|
def get_draft_variables_by_selectors(
|
||||||
|
self,
|
||||||
|
app_id: str,
|
||||||
|
selectors: Sequence[list[str]],
|
||||||
|
) -> list[WorkflowDraftVariable]:
|
||||||
|
ors = []
|
||||||
|
for selector in selectors:
|
||||||
|
node_id, name = selector
|
||||||
|
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
|
||||||
|
|
||||||
|
# NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as
|
||||||
|
# each expression includes conditions on both `node_id` and `name` (which are covered by the unique index),
|
||||||
|
# PostgreSQL can efficiently retrieve the results using a bitmap index scan.
|
||||||
|
#
|
||||||
|
# Alternatively, a `SELECT` statement could be constructed for each selector and
|
||||||
|
# combined using `UNION` to fetch all rows.
|
||||||
|
# Benchmarking indicates that both approaches yield comparable performance.
|
||||||
|
variables = (
|
||||||
|
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all()
|
||||||
|
)
|
||||||
|
return variables
|
||||||
|
|
||||||
def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]):
|
def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]):
|
||||||
variable_builder = _DraftVariableBuilder(app_id=app_id)
|
variable_builder = _DraftVariableBuilder(app_id=app_id)
|
||||||
variable_builder.build(node_id=node_id, node_type=node_type, output=output)
|
variable_builder.build(node_id=node_id, node_type=node_type, output=output)
|
||||||
@ -42,29 +96,25 @@ class WorkflowDraftVariableService:
|
|||||||
if not draft_variables:
|
if not draft_variables:
|
||||||
return
|
return
|
||||||
|
|
||||||
# We may use SQLAlchemy ORM operation here. However, considering the fact that:
|
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
|
||||||
#
|
#
|
||||||
# 1. The variable saving process writes multiple rows into one table (`workflow_draft_variables`).
|
# 1. The variable saving process involves writing multiple rows to the
|
||||||
# Use batch insertion may increase performance dramatically.
|
# `workflow_draft_variables` table. Batch insertion significantly improves performance.
|
||||||
# 2. If we use ORM operation, we need to either:
|
# 2. Using the ORM would require either:
|
||||||
#
|
#
|
||||||
# a. Check the existence for each variable before insertion.
|
# a. Checking for the existence of each variable before insertion,
|
||||||
# b. Try insertion first, then do update if insertion fails due to unique index violation.
|
# resulting in 2n SQL statements for n variables and potential concurrency issues.
|
||||||
|
# b. Attempting insertion first, then updating if a unique index violation occurs,
|
||||||
|
# which still results in n to 2n SQL statements.
|
||||||
#
|
#
|
||||||
# Neither of the above is satisfactory.
|
# Both approaches are inefficient and suboptimal.
|
||||||
|
# 3. We do not need to retrieve the results of the SQL execution or populate ORM
|
||||||
|
# model instances with the returned values.
|
||||||
|
# 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all
|
||||||
|
# variables in a single SQL statement, avoiding the issues above.
|
||||||
#
|
#
|
||||||
# - For implementation "a", we need to issue `2n` sqls for `n` variables in output.
|
# For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific
|
||||||
# Besides, it's still suffer from concurrency issues.
|
# insert operations instead of the ORM layer.
|
||||||
# - For implementation "b", we need to issue `n` - `2n` sqls (depending on the existence of
|
|
||||||
# specific variable), which is lesser than plan "a" but still far from ideal.
|
|
||||||
#
|
|
||||||
# 3. We do not need the value of SQL execution, nor do we need populate those values back to ORM model
|
|
||||||
# instances.
|
|
||||||
# 4. Batch insertion can be combined with `ON CONFLICT DO UPDATE`, allows us to insert or update
|
|
||||||
# all variables in one SQL statement, and avoid all problems above.
|
|
||||||
#
|
|
||||||
# Given reasons above, we use query builder instead of using ORM layer,
|
|
||||||
# and rely on dialect specific insert operations.
|
|
||||||
if node_type == NodeType.CODE:
|
if node_type == NodeType.CODE:
|
||||||
# Clear existing variable for code node.
|
# Clear existing variable for code node.
|
||||||
self._session.query(WorkflowDraftVariable).filter(
|
self._session.query(WorkflowDraftVariable).filter(
|
||||||
@ -213,26 +263,6 @@ def should_save_output_variables_for_draft(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
# def should_save_output_variables_for_draft(invoke_from: InvokeFrom, node_exec: WorkflowNodeExecution) -> bool:
|
|
||||||
# # Only save output variables for debugging execution of workflow.
|
|
||||||
# if invoke_from != InvokeFrom.DEBUGGER:
|
|
||||||
# return False
|
|
||||||
# exec_metadata = node_exec.execution_metadata_dict
|
|
||||||
# if exec_metadata is None:
|
|
||||||
# # No execution metadata, assume the node is not in loop or iteration.
|
|
||||||
# return True
|
|
||||||
#
|
|
||||||
# # Currently we do not save output variables for nodes inside loop or iteration.
|
|
||||||
# loop_id = exec_metadata.get(NodeRunMetadataKey.LOOP_ID)
|
|
||||||
# if loop_id is not None:
|
|
||||||
# return False
|
|
||||||
# iteration_id = exec_metadata.get(NodeRunMetadataKey.ITERATION_ID)
|
|
||||||
# if iteration_id is not None:
|
|
||||||
# return False
|
|
||||||
# return True
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
class _DraftVariableBuilder:
|
class _DraftVariableBuilder:
|
||||||
_app_id: str
|
_app_id: str
|
||||||
_draft_vars: list[WorkflowDraftVariable]
|
_draft_vars: list[WorkflowDraftVariable]
|
||||||
@ -280,7 +310,8 @@ class _DraftVariableBuilder:
|
|||||||
original_node_id = node_id
|
original_node_id = node_id
|
||||||
for name, value in output.items():
|
for name, value in output.items():
|
||||||
value_seg = variable_factory.build_segment(value)
|
value_seg = variable_factory.build_segment(value)
|
||||||
if is_dummy_output_variable(name):
|
node_id, name = self._normalize_variable_for_start_node(node_id, name)
|
||||||
|
if node_id != SYSTEM_VARIABLE_NODE_ID:
|
||||||
self._draft_vars.append(
|
self._draft_vars.append(
|
||||||
WorkflowDraftVariable.new_node_variable(
|
WorkflowDraftVariable.new_node_variable(
|
||||||
app_id=self._app_id,
|
app_id=self._app_id,
|
||||||
@ -302,20 +333,9 @@ class _DraftVariableBuilder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_variable_for_start_node(node_type: NodeType, node_id: str, name: str):
|
def _normalize_variable_for_start_node(node_id: str, name: str) -> tuple[str, str]:
|
||||||
if node_type != NodeType.START:
|
|
||||||
return node_id, name
|
|
||||||
|
|
||||||
# TODO(QuantumGhost): need special handling for dummy output variable in
|
|
||||||
# `Start` node.
|
|
||||||
if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."):
|
if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."):
|
||||||
return node_id, name
|
return node_id, name
|
||||||
_logger.debug(
|
|
||||||
"Normalizing variable: node_type=%s, node_id=%s, name=%s",
|
|
||||||
node_type,
|
|
||||||
node_id,
|
|
||||||
name,
|
|
||||||
)
|
|
||||||
node_id, name_ = name.split(".", maxsplit=1)
|
node_id, name_ = name.split(".", maxsplit=1)
|
||||||
return node_id, name_
|
return node_id, name_
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Generator, Sequence
|
from collections.abc import Callable, Generator, Sequence
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from inspect import isgenerator
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@ -28,6 +27,7 @@ from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_M
|
|||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from libs import gen_utils
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import App, AppMode
|
from models.model import App, AppMode
|
||||||
from models.tools import WorkflowToolProvider
|
from models.tools import WorkflowToolProvider
|
||||||
@ -42,7 +42,11 @@ from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
|
|||||||
from services.workflow.workflow_converter import WorkflowConverter
|
from services.workflow.workflow_converter import WorkflowConverter
|
||||||
|
|
||||||
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||||
from .workflow_draft_variable_service import WorkflowDraftVariableService, should_save_output_variables_for_draft
|
from .workflow_draft_variable_service import (
|
||||||
|
DraftVarLoader,
|
||||||
|
WorkflowDraftVariableService,
|
||||||
|
should_save_output_variables_for_draft,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowService:
|
class WorkflowService:
|
||||||
@ -310,26 +314,28 @@ class WorkflowService:
|
|||||||
if not draft_workflow:
|
if not draft_workflow:
|
||||||
raise ValueError("Workflow not initialized")
|
raise ValueError("Workflow not initialized")
|
||||||
|
|
||||||
# conv_vars = common_helpers.get_conversation_variables()
|
# TODO(QuantumGhost): We may get rid of the `list_conversation_variables`
|
||||||
|
# here, and rely on `DraftVarLoader` to load conversation variables.
|
||||||
# run draft workflow node
|
|
||||||
start_at = time.perf_counter()
|
|
||||||
with Session(bind=db.engine) as session:
|
with Session(bind=db.engine) as session:
|
||||||
# TODO(QunatumGhost): inject conversation variables
|
|
||||||
# to variable pool.
|
|
||||||
draft_var_srv = WorkflowDraftVariableService(session)
|
draft_var_srv = WorkflowDraftVariableService(session)
|
||||||
|
|
||||||
conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id)
|
conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id)
|
||||||
conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables}
|
conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables}
|
||||||
|
|
||||||
node_execution = self._handle_node_run_result(
|
variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id)
|
||||||
invoke_node_fn=lambda: WorkflowEntry.single_step_run(
|
run = WorkflowEntry.single_step_run(
|
||||||
workflow=draft_workflow,
|
workflow=draft_workflow,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
user_inputs=user_inputs,
|
user_inputs=user_inputs,
|
||||||
user_id=account.id,
|
user_id=account.id,
|
||||||
conversation_variables=conv_var_mapping,
|
conversation_variables=conv_var_mapping,
|
||||||
),
|
variable_loader=variable_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
# run draft workflow node
|
||||||
|
start_at = time.perf_counter()
|
||||||
|
node_execution = self._handle_node_run_result(
|
||||||
|
invoke_node_fn=lambda: run,
|
||||||
start_at=start_at,
|
start_at=start_at,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
@ -403,7 +409,7 @@ class WorkflowService:
|
|||||||
) -> NodeExecution:
|
) -> NodeExecution:
|
||||||
try:
|
try:
|
||||||
node_instance, generator = invoke_node_fn()
|
node_instance, generator = invoke_node_fn()
|
||||||
generator = _inspect_generator(generator)
|
generator = gen_utils.inspect(generator, logging.getLogger(__name__))
|
||||||
|
|
||||||
node_run_result: NodeRunResult | None = None
|
node_run_result: NodeRunResult | None = None
|
||||||
for event in generator:
|
for event in generator:
|
||||||
@ -610,19 +616,3 @@ class WorkflowService:
|
|||||||
|
|
||||||
session.delete(workflow)
|
session.delete(workflow)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _inspect_generator(gen: Generator[Any] | Any) -> Any:
|
|
||||||
if not isgenerator(gen):
|
|
||||||
return gen
|
|
||||||
|
|
||||||
def wrapper():
|
|
||||||
for item in gen:
|
|
||||||
logging.getLogger(__name__).info(
|
|
||||||
"received generator item, type=%s, value=%s",
|
|
||||||
type(item),
|
|
||||||
item,
|
|
||||||
)
|
|
||||||
yield item
|
|
||||||
|
|
||||||
return wrapper()
|
|
||||||
|
@ -4,40 +4,42 @@ import uuid
|
|||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from factories.variable_factory import build_segment
|
from factories.variable_factory import build_segment
|
||||||
from models import db
|
from models import db
|
||||||
from models.workflow import WorkflowDraftVariable
|
from models.workflow import WorkflowDraftVariable
|
||||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("flask_req_ctx")
|
@pytest.mark.usefixtures("flask_req_ctx")
|
||||||
class TestWorkflowDraftVariableService(unittest.TestCase):
|
class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||||
_test_app_id: str
|
_test_app_id: str
|
||||||
_session: Session
|
_session: Session
|
||||||
|
_node1_id = "test_node_1"
|
||||||
_node2_id = "test_node_2"
|
_node2_id = "test_node_2"
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._test_app_id = str(uuid.uuid4())
|
self._test_app_id = str(uuid.uuid4())
|
||||||
self._session: Session = db.session
|
self._session: Session = db.session
|
||||||
sys_var = WorkflowDraftVariable.create_sys_variable(
|
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
name="sys_var",
|
name="sys_var",
|
||||||
value=build_segment("sys_value"),
|
value=build_segment("sys_value"),
|
||||||
)
|
)
|
||||||
conv_var = WorkflowDraftVariable.create_conversation_variable(
|
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
name="conv_var",
|
name="conv_var",
|
||||||
value=build_segment("conv_value"),
|
value=build_segment("conv_value"),
|
||||||
)
|
)
|
||||||
node2_vars = [
|
node2_vars = [
|
||||||
WorkflowDraftVariable.create_node_variable(
|
WorkflowDraftVariable.new_node_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
node_id=self._node2_id,
|
node_id=self._node2_id,
|
||||||
name="int_var",
|
name="int_var",
|
||||||
value=build_segment(1),
|
value=build_segment(1),
|
||||||
visible=False,
|
visible=False,
|
||||||
),
|
),
|
||||||
WorkflowDraftVariable.create_node_variable(
|
WorkflowDraftVariable.new_node_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
node_id=self._node2_id,
|
node_id=self._node2_id,
|
||||||
name="str_var",
|
name="str_var",
|
||||||
@ -45,9 +47,9 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
|||||||
visible=True,
|
visible=True,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
node1_var = WorkflowDraftVariable.create_node_variable(
|
node1_var = WorkflowDraftVariable.new_node_variable(
|
||||||
app_id=self._test_app_id,
|
app_id=self._test_app_id,
|
||||||
node_id="node_1",
|
node_id=self._node1_id,
|
||||||
name="str_var",
|
name="str_var",
|
||||||
value=build_segment("str_value"),
|
value=build_segment("str_value"),
|
||||||
visible=True,
|
visible=True,
|
||||||
@ -92,7 +94,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
|||||||
|
|
||||||
def test_get_node_variable(self):
|
def test_get_node_variable(self):
|
||||||
srv = self._get_test_srv()
|
srv = self._get_test_srv()
|
||||||
node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var")
|
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
|
||||||
assert node_var.id == self._node1_str_var_id
|
assert node_var.id == self._node1_str_var_id
|
||||||
assert node_var.name == "str_var"
|
assert node_var.name == "str_var"
|
||||||
assert node_var.get_value() == build_segment("str_value")
|
assert node_var.get_value() == build_segment("str_value")
|
||||||
@ -138,5 +140,86 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
|||||||
def test__list_node_variables(self):
|
def test__list_node_variables(self):
|
||||||
srv = self._get_test_srv()
|
srv = self._get_test_srv()
|
||||||
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
||||||
assert len(node_vars) == 2
|
assert len(node_vars.variables) == 2
|
||||||
assert {v.id for v in node_vars} == set(self._node2_var_ids)
|
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
|
||||||
|
|
||||||
|
def test_get_draft_variables_by_selectors(self):
|
||||||
|
srv = self._get_test_srv()
|
||||||
|
selectors = [
|
||||||
|
[self._node1_id, "str_var"],
|
||||||
|
[self._node2_id, "str_var"],
|
||||||
|
[self._node2_id, "int_var"],
|
||||||
|
]
|
||||||
|
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
|
||||||
|
assert len(variables) == 3
|
||||||
|
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("flask_req_ctx")
|
||||||
|
class TestDraftVariableLoader(unittest.TestCase):
|
||||||
|
_test_app_id: str
|
||||||
|
|
||||||
|
_node1_id = "test_loader_node_1"
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._test_app_id = str(uuid.uuid4())
|
||||||
|
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
name="sys_var",
|
||||||
|
value=build_segment("sys_value"),
|
||||||
|
)
|
||||||
|
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
name="conv_var",
|
||||||
|
value=build_segment("conv_value"),
|
||||||
|
)
|
||||||
|
node_var = WorkflowDraftVariable.new_node_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
node_id=self._node1_id,
|
||||||
|
name="str_var",
|
||||||
|
value=build_segment("str_value"),
|
||||||
|
visible=True,
|
||||||
|
)
|
||||||
|
_variables = [
|
||||||
|
node_var,
|
||||||
|
sys_var,
|
||||||
|
conv_var,
|
||||||
|
]
|
||||||
|
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
session.add_all(_variables)
|
||||||
|
session.flush()
|
||||||
|
session.commit()
|
||||||
|
self._variable_ids = [v.id for v in _variables]
|
||||||
|
self._node_var_id = node_var.id
|
||||||
|
self._sys_var_id = sys_var.id
|
||||||
|
self._conv_var_id = conv_var.id
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
|
||||||
|
synchronize_session=False
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def test_variable_loader_with_empty_selector(self):
|
||||||
|
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
|
||||||
|
variables = var_loader.load_variables([])
|
||||||
|
assert len(variables) == 0
|
||||||
|
|
||||||
|
def test_variable_loader_with_non_empty_selector(self):
|
||||||
|
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
|
||||||
|
variables = var_loader.load_variables(
|
||||||
|
[
|
||||||
|
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
|
||||||
|
[CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
|
||||||
|
[self._node1_id, "str_var"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert len(variables) == 3
|
||||||
|
conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID)
|
||||||
|
assert conv_var.id == self._conv_var_id
|
||||||
|
sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
|
||||||
|
assert sys_var.id == self._sys_var_id
|
||||||
|
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
|
||||||
|
assert node1_var.id == self._node_var_id
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@ -232,6 +233,10 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File]:
|
|||||||
@given(_scalar_value())
|
@given(_scalar_value())
|
||||||
def test_build_segment_and_extract_values_for_scalar_types(value):
|
def test_build_segment_and_extract_values_for_scalar_types(value):
|
||||||
seg = variable_factory.build_segment(value)
|
seg = variable_factory.build_segment(value)
|
||||||
|
# nan == nan yields false, so we need to use `math.isnan` to check `seg.value` here.
|
||||||
|
if isinstance(value, float) and math.isnan(value):
|
||||||
|
assert math.isnan(seg.value)
|
||||||
|
else:
|
||||||
assert seg.value == value
|
assert seg.value == value
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,22 +1,10 @@
|
|||||||
from core.variables import SecretVariable
|
import dataclasses
|
||||||
|
|
||||||
from core.workflow.entities.variable_entities import VariableSelector
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
from core.workflow.enums import SystemVariableKey
|
|
||||||
from core.workflow.utils import variable_template_parser
|
from core.workflow.utils import variable_template_parser
|
||||||
|
|
||||||
|
|
||||||
def test_extract_selectors_from_template():
|
def test_extract_selectors_from_template():
|
||||||
variable_pool = VariablePool(
|
|
||||||
system_variables={
|
|
||||||
SystemVariableKey("user_id"): "fake-user-id",
|
|
||||||
},
|
|
||||||
user_inputs={},
|
|
||||||
environment_variables=[
|
|
||||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
|
||||||
],
|
|
||||||
conversation_variables=[],
|
|
||||||
)
|
|
||||||
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
|
|
||||||
template = (
|
template = (
|
||||||
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
||||||
)
|
)
|
||||||
@ -26,3 +14,35 @@ def test_extract_selectors_from_template():
|
|||||||
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
|
VariableSelector(variable="#node_id.custom_query#", value_selector=["node_id", "custom_query"]),
|
||||||
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
|
VariableSelector(variable="#env.secret_key#", value_selector=["env", "secret_key"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_references():
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TestCase:
|
||||||
|
name: str
|
||||||
|
template: str
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
TestCase(
|
||||||
|
name="lack of closing brace",
|
||||||
|
template="Hello, {{#sys.user_id#",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="lack of opening brace",
|
||||||
|
template="Hello, #sys.user_id#}}",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="lack selector name",
|
||||||
|
template="Hello, {{#sys#}}",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="empty node name part",
|
||||||
|
template="Hello, {{#.user_id#}}",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for idx, c in enumerate(cases, 1):
|
||||||
|
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||||
|
selectors = variable_template_parser.extract_selectors_from_template(c.template)
|
||||||
|
assert selectors == [], fail_msg
|
||||||
|
parser = variable_template_parser.VariableTemplateParser(c.template)
|
||||||
|
assert parser.extract_variable_selectors() == [], fail_msg
|
||||||
|
@ -5,7 +5,7 @@ from uuid import uuid4
|
|||||||
import contexts
|
import contexts
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
||||||
from models.workflow import Workflow, WorkflowNodeExecution
|
from models.workflow import Workflow, WorkflowNodeExecution, is_system_variable_editable
|
||||||
|
|
||||||
|
|
||||||
def test_environment_variables():
|
def test_environment_variables():
|
||||||
@ -149,3 +149,21 @@ class TestWorkflowNodeExecution:
|
|||||||
original = {"a": 1, "b": ["2"]}
|
original = {"a": 1, "b": ["2"]}
|
||||||
node_exec.execution_metadata = json.dumps(original)
|
node_exec.execution_metadata = json.dumps(original)
|
||||||
assert node_exec.execution_metadata_dict == original
|
assert node_exec.execution_metadata_dict == original
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsSystemVariableEditable:
|
||||||
|
def test_is_system_variable(self):
|
||||||
|
cases = [
|
||||||
|
("query", True),
|
||||||
|
("files", True),
|
||||||
|
("dialogue_count", False),
|
||||||
|
("conversation_id", False),
|
||||||
|
("user_id", False),
|
||||||
|
("app_id", False),
|
||||||
|
("workflow_id", False),
|
||||||
|
("workflow_run_id", False),
|
||||||
|
]
|
||||||
|
for name, editable in cases:
|
||||||
|
assert editable == is_system_variable_editable(name)
|
||||||
|
|
||||||
|
assert is_system_variable_editable("invalid_or_new_system_variable") == False
|
||||||
|
@ -0,0 +1,82 @@
|
|||||||
|
import dataclasses
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
|
from core.workflow.nodes import NodeType
|
||||||
|
from factories.variable_factory import build_segment
|
||||||
|
from models.workflow import WorkflowDraftVariable
|
||||||
|
from services.workflow_draft_variable_service import _DraftVariableBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class TestDraftVariableBuilder:
|
||||||
|
def _get_test_app_id(self):
|
||||||
|
suffix = secrets.token_hex(6)
|
||||||
|
return f"test_app_id_{suffix}"
|
||||||
|
|
||||||
|
def test_get_variables(self):
|
||||||
|
test_app_id = self._get_test_app_id()
|
||||||
|
builder = _DraftVariableBuilder(app_id=test_app_id)
|
||||||
|
variables = [
|
||||||
|
WorkflowDraftVariable.new_node_variable(
|
||||||
|
app_id=test_app_id,
|
||||||
|
node_id="test_node_1",
|
||||||
|
name="test_var_1",
|
||||||
|
value=build_segment("test_value_1"),
|
||||||
|
visible=True,
|
||||||
|
),
|
||||||
|
WorkflowDraftVariable.new_sys_variable(
|
||||||
|
app_id=test_app_id,
|
||||||
|
name="test_sys_var",
|
||||||
|
value=build_segment("test_sys_value"),
|
||||||
|
),
|
||||||
|
WorkflowDraftVariable.new_conversation_variable(
|
||||||
|
app_id=test_app_id,
|
||||||
|
name="test_conv_var",
|
||||||
|
value=build_segment("test_conv_value"),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
builder._draft_vars = variables
|
||||||
|
assert builder.get_variables() == variables
|
||||||
|
|
||||||
|
def test__should_variable_be_visible(self):
|
||||||
|
assert _DraftVariableBuilder._should_variable_be_visible(NodeType.IF_ELSE, "123_456", "output") == False
|
||||||
|
assert _DraftVariableBuilder._should_variable_be_visible(NodeType.START, "123", "output") == True
|
||||||
|
|
||||||
|
def test__normalize_variable_for_start_node(self):
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class TestCase:
|
||||||
|
name: str
|
||||||
|
input_node_id: str
|
||||||
|
input_name: str
|
||||||
|
expected_node_id: str
|
||||||
|
expected_name: str
|
||||||
|
|
||||||
|
_NODE_ID = "1747228642872"
|
||||||
|
cases = [
|
||||||
|
TestCase(
|
||||||
|
name="name with `sys.` prefix should return the system node_id",
|
||||||
|
input_node_id=_NODE_ID,
|
||||||
|
input_name="sys.workflow_id",
|
||||||
|
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||||
|
expected_name="workflow_id",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="name without `sys.` prefix should return the original input node_id",
|
||||||
|
input_node_id=_NODE_ID,
|
||||||
|
input_name="start_input",
|
||||||
|
expected_node_id=_NODE_ID,
|
||||||
|
expected_name="start_input",
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="dummy_variable should return the original input node_id",
|
||||||
|
input_node_id=_NODE_ID,
|
||||||
|
input_name="__dummy__",
|
||||||
|
expected_node_id=_NODE_ID,
|
||||||
|
expected_name="__dummy__",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for idx, c in enumerate(cases, 1):
|
||||||
|
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||||
|
node_id, name = _DraftVariableBuilder._normalize_variable_for_start_node(c.input_node_id, c.input_name)
|
||||||
|
assert node_id == c.expected_node_id, fail_msg
|
||||||
|
assert name == c.expected_name, fail_msg
|
Loading…
x
Reference in New Issue
Block a user