feat(api): Introduce WorkflowDraftVariable Model (#19737)

- Introduce `WorkflowDraftVariable` model and the corresponding migration.
- Implement `EnumText`,  a custom column type for SQLAlchemy designed
  to work seamlessly with enumeration classes based on `StrEnum`.
This commit is contained in:
QuantumGhost 2025-05-19 22:59:56 +08:00 committed by GitHub
parent bbebf9ad3e
commit 6a9e0b1005
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 533 additions and 10 deletions

View File

@ -4,13 +4,14 @@ SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
import json
import logging
from collections.abc import Sequence
from typing import Optional, Union
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import (
NodeExecution,
NodeExecutionStatus,
@ -122,7 +123,12 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
status=status,
error=db_model.error,
elapsed_time=db_model.elapsed_time,
metadata=metadata,
# FIXME(QuantumGhost): a temporary workaround for the following type check failure in Python 3.11.
# However, this problem does not occur in Python 3.12.
#
# A case of this error is:
# https://github.com/langgenius/dify/actions/runs/15112698604/job/42475659482?pr=19737#step:9:24
metadata=cast(Mapping[NodeRunMetadataKey, Any] | None, metadata),
created_at=db_model.created_at,
finished_at=db_model.finished_at,
)

View File

@ -0,0 +1,7 @@
# The minimal selector length for valid variables.
#
# The first element of the selector is the node id, and the second element is the variable name.
#
# If the selector has more than 2 elements, the remaining elements are the key / index path
# used to extract part of the variable value.
MIN_SELECTORS_LENGTH = 2

View File

@ -0,0 +1,8 @@
from collections.abc import Iterable, Sequence
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
    """Build a variable selector from a node id, a variable name and optional paths.

    The result is a list that always starts with ``node_id`` and ``name``;
    every element of ``paths`` follows them in iteration order.
    """
    # A falsy ``paths`` (including the empty default) contributes nothing.
    return [node_id, name, *(paths or ())]

View File

@ -0,0 +1,51 @@
"""add WorkflowDraftVariable model
Revision ID: 2adcbe1f5dfb
Revises: d28f2004b072
Create Date: 2025-05-15 15:31:03.128680
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "2adcbe1f5dfb"
down_revision = "d28f2004b072"
branch_labels = None
depends_on = None
def upgrade():
    """Create the ``workflow_draft_variables`` table with its primary key and unique constraint."""
    # ### commands auto generated by Alembic - please adjust! ###
    columns = [
        sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
        sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("app_id", models.types.StringUUID(), nullable=False),
        sa.Column("last_edited_at", sa.DateTime(), nullable=True),
        sa.Column("node_id", sa.String(length=255), nullable=False),
        sa.Column("name", sa.String(length=255), nullable=False),
        sa.Column("description", sa.String(length=255), nullable=False),
        sa.Column("selector", sa.String(length=255), nullable=False),
        sa.Column("value_type", sa.String(length=20), nullable=False),
        sa.Column("value", sa.Text(), nullable=False),
        sa.Column("visible", sa.Boolean(), nullable=False),
        sa.Column("editable", sa.Boolean(), nullable=False),
    ]
    constraints = [
        sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")),
        # One row per (app_id, node_id, name).
        sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")),
    ]
    op.create_table("workflow_draft_variables", *columns, *constraints)
    # ### end Alembic commands ###
def downgrade():
    """Drop the ``workflow_draft_variables`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # No separate index cleanup is issued: dropping the table also drops any index associated with it.
    table_name = "workflow_draft_variables"
    op.drop_table(table_name)
    # ### end Alembic commands ###

View File

@ -14,3 +14,10 @@ class UserFrom(StrEnum):
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
APP_RUN = "app-run"
class DraftVariableType(StrEnum):
# node means that the correspond variable
NODE = "node"
SYS = "sys"
CONVERSATION = "conversation"

View File

@ -1,4 +1,7 @@
from sqlalchemy import CHAR, TypeDecorator
import enum
from typing import Generic, TypeVar
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
@ -24,3 +27,51 @@ class StringUUID(TypeDecorator):
if value is None:
return value
return str(value)
_E = TypeVar("_E", bound=enum.StrEnum)


class EnumText(TypeDecorator, Generic[_E]):
    """SQLAlchemy column type that stores ``StrEnum`` members as ``VARCHAR`` text.

    Values bound to a statement may be members of the enum class or plain
    strings equal to one of its values; values loaded from the database are
    converted back to members of the enum class.
    """

    impl = VARCHAR
    cache_ok = True

    _length: int
    _enum_class: type[_E]

    def __init__(self, enum_class: type[_E], length: int | None = None):
        """Configure the type for ``enum_class``.

        Args:
            enum_class: the ``StrEnum`` subclass whose members are stored.
            length: VARCHAR length. When omitted, the longest enum value
                length is used, with a floor of 20.

        Raises:
            ValueError: if ``length`` is shorter than the longest enum value.
        """
        self._enum_class = enum_class
        max_enum_value_len = max(len(e.value) for e in enum_class)
        if length is not None:
            if length < max_enum_value_len:
                raise ValueError("length should be greater than or equal to the longest enum value length.")
            self._length = length
        else:
            # leave some room for future longer enum values.
            self._length = max(max_enum_value_len, 20)

        # With `cache_ok = True`, SQLAlchemy builds this type's cache key from
        # instance attributes whose names match the `__init__` parameters and
        # do not start with an underscore. Without these public mirrors every
        # `EnumText` would share the cache key `(EnumText,)`, no matter which
        # enum class or length it was configured with.
        self.enum_class = enum_class
        self.length = self._length

    def process_bind_param(self, value: _E | str | None, dialect):
        """Convert a Python value to the string sent to the database.

        Raises:
            ValueError: if a plain string is not a value of the enum class.
            TypeError: if ``value`` is neither a string nor an enum member.
        """
        if value is None:
            return value
        if isinstance(value, self._enum_class):
            return value.value
        elif isinstance(value, str):
            # Called only for validation: raises ValueError for an unknown value.
            self._enum_class(value)
            return value
        else:
            raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")

    def load_dialect_impl(self, dialect):
        """Return the dialect-level VARCHAR type using the configured length."""
        return dialect.type_descriptor(VARCHAR(self._length))

    def process_result_value(self, value, dialect) -> _E | None:
        """Convert a database value to a member of the enum class.

        Raises:
            TypeError: if the loaded value is not a string.
            ValueError: if the loaded string is not a value of the enum class.
        """
        if value is None:
            return value
        if not isinstance(value, str):
            raise TypeError(f"expected str, got {type(value)}")
        return self._enum_class(value)

    def compare_values(self, x, y):
        """Compare two values; ``None`` only equals ``None``."""
        if x is None or y is None:
            return x is y
        return x == y

View File

@ -1,29 +1,36 @@
import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from uuid import uuid4
from core.variables import utils as variable_utils
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
if TYPE_CHECKING:
from models.model import AppMode
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy import UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
import contexts
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Variable
from core.variables import SecretVariable, Segment, SegmentType, Variable
from factories import variable_factory
from libs import helper
from .account import Account
from .base import Base
from .engine import db
from .enums import CreatorUserRole
from .types import StringUUID
from .enums import CreatorUserRole, DraftVariableType
from .types import EnumText, StringUUID
_logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from models.model import AppMode
@ -651,7 +658,7 @@ class WorkflowNodeExecution(Base):
return json.loads(self.inputs) if self.inputs else None
@property
def outputs_dict(self):
def outputs_dict(self) -> dict[str, Any] | None:
return json.loads(self.outputs) if self.outputs else None
@property
@ -659,7 +666,7 @@ class WorkflowNodeExecution(Base):
return json.loads(self.process_data) if self.process_data else None
@property
def execution_metadata_dict(self):
def execution_metadata_dict(self) -> dict[str, Any] | None:
return json.loads(self.execution_metadata) if self.execution_metadata else None
@property
@ -797,3 +804,202 @@ class ConversationVariable(Base):
def to_variable(self) -> Variable:
    """Decode the JSON in ``self.data`` and build a conversation ``Variable`` from it."""
    return variable_factory.build_conversation_variable_from_mapping(json.loads(self.data))
# Only `sys.query` and `sys.files` could be modified.
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
def _naive_utc_datetime():
return datetime.now(UTC).replace(tzinfo=None)
class WorkflowDraftVariable(Base):
@staticmethod
def unique_columns() -> list[str]:
return [
"app_id",
"node_id",
"name",
]
__tablename__ = "workflow_draft_variables"
__table_args__ = (UniqueConstraint(*unique_columns()),)
# id is the unique identifier of a draft variable.
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
created_at = mapped_column(
db.DateTime,
nullable=False,
default=_naive_utc_datetime,
server_default=func.current_timestamp(),
)
updated_at = mapped_column(
db.DateTime,
nullable=False,
default=_naive_utc_datetime,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
# "`app_id` maps to the `id` field in the `model.App` model."
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# `last_edited_at` records when the value of a given draft variable
# is edited.
#
# If it's not edited after creation, its value is `None`.
last_edited_at: Mapped[datetime | None] = mapped_column(
db.DateTime,
nullable=True,
default=None,
)
# The `node_id` field is special.
#
# If the variable is a conversation variable or a system variable, then the value of `node_id`
# is `conversation` or `sys`, respective.
#
# Otherwise, if the variable is a variable belonging to a specific node, the value of `_node_id` is
# the identity of correspond node in graph definition. An example of node id is `"1745769620734"`.
#
# However, there's one caveat. The id of the first "Answer" node in chatflow is "answer". (Other
# "Answer" node conform the rules above.)
node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id")
# From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than
# 80 chars.
#
# ref: api/core/workflow/entities/variable_pool.py:18
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(
sa.String(255),
default="",
nullable=False,
)
selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector")
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
# JSON string
value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value")
# visible
visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
def get_selector(self) -> list[str]:
selector = json.loads(self.selector)
if not isinstance(selector, list):
_logger.error(
"invalid selector loaded from database, type=%s, value=%s",
type(selector),
self.selector,
)
raise ValueError("invalid selector.")
return selector
def _set_selector(self, value: list[str]):
self.selector = json.dumps(value)
def get_value(self) -> Segment | None:
return build_segment(json.loads(self.value))
def set_name(self, name: str):
self.name = name
self._set_selector([self.node_id, name])
def set_value(self, value: Segment):
self.value = json.dumps(value.value)
self.value_type = value.value_type
def get_node_id(self) -> str | None:
if self.get_variable_type() == DraftVariableType.NODE:
return self.node_id
else:
return None
def get_variable_type(self) -> DraftVariableType:
match self.node_id:
case DraftVariableType.CONVERSATION:
return DraftVariableType.CONVERSATION
case DraftVariableType.SYS:
return DraftVariableType.SYS
case _:
return DraftVariableType.NODE
@classmethod
def _new(
cls,
*,
app_id: str,
node_id: str,
name: str,
value: Segment,
description: str = "",
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.created_at = _naive_utc_datetime()
variable.updated_at = _naive_utc_datetime()
variable.description = description
variable.app_id = app_id
variable.node_id = node_id
variable.name = name
variable.app_id = app_id
variable.set_value(value)
variable._set_selector(list(variable_utils.to_selector(node_id, name)))
return variable
@classmethod
def new_conversation_variable(
cls,
*,
app_id: str,
name: str,
value: Segment,
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
name=name,
value=value,
)
return variable
@classmethod
def new_sys_variable(
cls,
*,
app_id: str,
name: str,
value: Segment,
editable: bool = False,
) -> "WorkflowDraftVariable":
variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value)
variable.editable = editable
return variable
@classmethod
def new_node_variable(
cls,
*,
app_id: str,
node_id: str,
name: str,
value: Segment,
visible: bool = True,
) -> "WorkflowDraftVariable":
variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value)
variable.visible = visible
variable.editable = True
return variable
@property
def edited(self):
return self.last_edited_at is not None
def is_system_variable_editable(name: str) -> bool:
    """Tell whether ``name`` is one of the system variables listed in ``_EDITABLE_SYSTEM_VARIABLE``."""
    editable = name in _EDITABLE_SYSTEM_VARIABLE
    return editable

View File

@ -0,0 +1,187 @@
from collections.abc import Callable, Iterable
from enum import StrEnum
from typing import Any, NamedTuple, TypeVar
import pytest
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
from sqlalchemy import insert
from sqlalchemy.orm import DeclarativeBase, Mapped, Session
from sqlalchemy.sql.sqltypes import VARCHAR
from models.types import EnumText
_user_type_admin = "admin"
_user_type_normal = "normal"
class _Base(DeclarativeBase):
    """Declarative base for the models defined in this test module."""
class _UserType(StrEnum):
    """Enum stored through ``EnumText`` in the test models; values come from the module constants."""

    admin = _user_type_admin
    normal = _user_type_normal
class _EnumWithLongValue(StrEnum):
unknown = "unknown"
a_really_long_enum_values = "a_really_long_enum_values"
class _User(_Base):
    """Test model with one required and one nullable ``EnumText`` column."""

    __tablename__ = "users"

    id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
    name: Mapped[str] = sa.Column(sa.String(length=255), nullable=False)
    # Required enum column; defaults to `_UserType.normal`.
    user_type: Mapped[_UserType] = sa.Column(
        EnumText(enum_class=_UserType),
        nullable=False,
        default=_UserType.normal,
    )
    # Optional enum column without a default.
    user_type_nullable: Mapped[_UserType | None] = sa.Column(
        EnumText(enum_class=_UserType),
        nullable=True,
    )
class _ColumnTest(_Base):
    """Test model covering the three ways ``EnumText`` picks its VARCHAR length."""

    __tablename__ = "column_test"

    id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
    # No explicit length, short enum values.
    user_type: Mapped[_UserType] = sa.Column(
        EnumText(enum_class=_UserType),
        nullable=False,
        default=_UserType.normal,
    )
    # Explicit length of 50.
    explicit_length: Mapped[_UserType | None] = sa.Column(
        EnumText(_UserType, length=50),
        nullable=True,
        default=_UserType.normal,
    )
    # No explicit length, enum with a long value.
    long_value: Mapped[_EnumWithLongValue] = sa.Column(
        EnumText(enum_class=_EnumWithLongValue),
        nullable=False,
    )
_T = TypeVar("_T")


def _first(it: Iterable[_T]) -> _T:
    """Return the first element of ``it``.

    The iterable is consumed completely before the first element is returned.

    Raises:
        ValueError: if ``it`` yields no elements.
    """
    items = [*it]
    if items:
        return items[0]
    raise ValueError("List is empty")
class TestEnumText:
    """Tests for ``EnumText`` against an in-memory SQLite database."""

    @staticmethod
    def _fresh_engine():
        """Create an in-memory SQLite engine with all ``_Base`` tables created."""
        engine = sa.create_engine("sqlite://", echo=False)
        _Base.metadata.create_all(engine)
        return engine

    def test_column_impl(self):
        """Reflected columns are VARCHAR with the expected length and nullability."""
        engine = self._fresh_engine()

        reflected = sa.inspect(engine).get_columns(_ColumnTest.__tablename__)

        def column_named(column_name: str):
            return _first(col for col in reflected if col["name"] == column_name)

        # Default length: short enum values are padded up to 20.
        user_type_column = column_named("user_type")
        user_type_sql_type = user_type_column["type"]
        assert isinstance(user_type_column["type"], VARCHAR)
        assert user_type_sql_type.length == 20
        assert user_type_column["nullable"] is False

        # Explicit length is used as given.
        explicit_length_column = column_named("explicit_length")
        explicit_sql_type = explicit_length_column["type"]
        assert isinstance(explicit_sql_type, VARCHAR)
        assert explicit_sql_type.length == 50
        assert explicit_length_column["nullable"] is True

        # Enum values longer than 20 characters widen the column.
        long_value_column = column_named("long_value")
        long_sql_type = long_value_column["type"]
        assert isinstance(long_sql_type, VARCHAR)
        assert long_sql_type.length == len(_EnumWithLongValue.a_really_long_enum_values)

    def test_insert_and_select(self):
        """Enum members and plain strings both round-trip to enum members."""
        engine = self._fresh_engine()

        with Session(engine) as session:
            # Inserted using enum members / None.
            admin = _User(
                name="admin",
                user_type=_UserType.admin,
                user_type_nullable=None,
            )
            session.add(admin)
            session.flush()
            admin_user_id = admin.id

            # Inserted using the plain string values.
            normal = _User(
                name="normal",
                user_type=_UserType.normal.value,
                user_type_nullable=_UserType.normal.value,
            )
            session.add(normal)
            session.flush()
            normal_user_id = normal.id
            session.commit()

        with Session(engine) as session:
            loaded_admin = session.query(_User).filter(_User.id == admin_user_id).first()
            assert loaded_admin.user_type == _UserType.admin
            assert loaded_admin.user_type_nullable is None

        with Session(engine) as session:
            loaded_normal = session.query(_User).filter(_User.id == normal_user_id).first()
            assert loaded_normal.user_type == _UserType.normal
            assert loaded_normal.user_type_nullable == _UserType.normal

    def test_insert_invalid_values(self):
        """Invalid values fail with ``StatementError`` wrapping ``ValueError`` / ``TypeError``."""

        def _session_insert_with_value(sess: Session, user_type: Any):
            # ORM path: add an instance and flush it.
            sess.add(_User(name="test_user", user_type=user_type))
            sess.flush()

        def _insert_with_user(sess: Session, user_type: Any):
            # Core path: execute an INSERT statement directly.
            values = {
                "name": "test_user",
                "user_type": user_type,
            }
            sess.execute(insert(_User).values(values))

        class TestCase(NamedTuple):
            name: str
            action: Callable[[Session], None]
            exc_type: type[Exception]

        engine = self._fresh_engine()
        cases = [
            TestCase(
                name="session insert with invalid value",
                action=lambda s: _session_insert_with_value(s, "invalid"),
                exc_type=ValueError,
            ),
            TestCase(
                name="session insert with invalid type",
                action=lambda s: _session_insert_with_value(s, 1),
                exc_type=TypeError,
            ),
            TestCase(
                name="insert with invalid value",
                action=lambda s: _insert_with_user(s, "invalid"),
                exc_type=ValueError,
            ),
            TestCase(
                name="insert with invalid type",
                action=lambda s: _insert_with_user(s, 1),
                exc_type=TypeError,
            ),
        ]
        for position, case in enumerate(cases, 1):
            with pytest.raises(sa_exc.StatementError) as caught, Session(engine) as session:
                case.action(session)

            assert isinstance(caught.value.orig, case.exc_type), f"test case {position} failed, name={case.name}"

    def test_select_invalid_values(self):
        """Loading a row whose stored text is not an enum value raises ``ValueError``."""
        engine = self._fresh_engine()

        # Raw SQL is used so the invalid value reaches the table without going through the ORM.
        insertion_sql = """
INSERT INTO users (id, name, user_type) VALUES
(1, 'invalid_value', 'invalid');
"""
        with Session(engine) as session:
            session.execute(sa.text(insertion_sql))
            session.commit()

        with pytest.raises(ValueError), Session(engine) as session:
            session.query(_User).filter(_User.id == 1).first()