mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-20 20:19:16 +08:00
feat(api): Introduce WorkflowDraftVariable
Model (#19737)
- Introduce `WorkflowDraftVariable` model and the corresponding migration. - Implement `EnumText`, a custom column type for SQLAlchemy designed to work seamlessly with enumeration classes based on `StrEnum`.
This commit is contained in:
parent
bbebf9ad3e
commit
6a9e0b1005
@ -4,13 +4,14 @@ SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_execution_entities import (
|
||||
NodeExecution,
|
||||
NodeExecutionStatus,
|
||||
@ -122,7 +123,12 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
status=status,
|
||||
error=db_model.error,
|
||||
elapsed_time=db_model.elapsed_time,
|
||||
metadata=metadata,
|
||||
# FIXME(QuantumGhost): a temporary workaround for the following type check failure in Python 3.11.
|
||||
# However, this problem is not occurred in Python 3.12.
|
||||
#
|
||||
# A case of this error is:
|
||||
# https://github.com/langgenius/dify/actions/runs/15112698604/job/42475659482?pr=19737#step:9:24
|
||||
metadata=cast(Mapping[NodeRunMetadataKey, Any] | None, metadata),
|
||||
created_at=db_model.created_at,
|
||||
finished_at=db_model.finished_at,
|
||||
)
|
||||
|
7
api/core/variables/consts.py
Normal file
7
api/core/variables/consts.py
Normal file
@ -0,0 +1,7 @@
|
||||
# The minimal selector length for valid variables.
|
||||
#
|
||||
# The first element of the selector is the node id, and the second element is the variable name.
|
||||
#
|
||||
# If the selector length is more than 2, the remaining parts are the keys / indexes paths used
|
||||
# to extract part of the variable value.
|
||||
MIN_SELECTORS_LENGTH = 2
|
8
api/core/variables/utils.py
Normal file
8
api/core/variables/utils.py
Normal file
@ -0,0 +1,8 @@
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
|
||||
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
||||
selectors = [node_id, name]
|
||||
if paths:
|
||||
selectors.extend(paths)
|
||||
return selectors
|
@ -0,0 +1,51 @@
|
||||
"""add WorkflowDraftVariable model
|
||||
|
||||
Revision ID: 2adcbe1f5dfb
|
||||
Revises: d28f2004b072
|
||||
Create Date: 2025-05-15 15:31:03.128680
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2adcbe1f5dfb"
|
||||
down_revision = "d28f2004b072"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"workflow_draft_variables",
|
||||
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("last_edited_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("node_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.String(length=255), nullable=False),
|
||||
sa.Column("selector", sa.String(length=255), nullable=False),
|
||||
sa.Column("value_type", sa.String(length=20), nullable=False),
|
||||
sa.Column("value", sa.Text(), nullable=False),
|
||||
sa.Column("visible", sa.Boolean(), nullable=False),
|
||||
sa.Column("editable", sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")),
|
||||
sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")),
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
# Dropping `workflow_draft_variables` also drops any index associated with it.
|
||||
op.drop_table("workflow_draft_variables")
|
||||
|
||||
# ### end Alembic commands ###
|
@ -14,3 +14,10 @@ class UserFrom(StrEnum):
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
APP_RUN = "app-run"
|
||||
|
||||
|
||||
class DraftVariableType(StrEnum):
|
||||
# node means that the correspond variable
|
||||
NODE = "node"
|
||||
SYS = "sys"
|
||||
CONVERSATION = "conversation"
|
||||
|
@ -1,4 +1,7 @@
|
||||
from sqlalchemy import CHAR, TypeDecorator
|
||||
import enum
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
@ -24,3 +27,51 @@ class StringUUID(TypeDecorator):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
_E = TypeVar("_E", bound=enum.StrEnum)
|
||||
|
||||
|
||||
class EnumText(TypeDecorator, Generic[_E]):
|
||||
impl = VARCHAR
|
||||
cache_ok = True
|
||||
|
||||
_length: int
|
||||
_enum_class: type[_E]
|
||||
|
||||
def __init__(self, enum_class: type[_E], length: int | None = None):
|
||||
self._enum_class = enum_class
|
||||
max_enum_value_len = max(len(e.value) for e in enum_class)
|
||||
if length is not None:
|
||||
if length < max_enum_value_len:
|
||||
raise ValueError("length should be greater than enum value length.")
|
||||
self._length = length
|
||||
else:
|
||||
# leave some rooms for future longer enum values.
|
||||
self._length = max(max_enum_value_len, 20)
|
||||
|
||||
def process_bind_param(self, value: _E | str | None, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, self._enum_class):
|
||||
return value.value
|
||||
elif isinstance(value, str):
|
||||
self._enum_class(value)
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(VARCHAR(self._length))
|
||||
|
||||
def process_result_value(self, value, dialect) -> _E | None:
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(f"expected str, got {type(value)}")
|
||||
return self._enum_class(value)
|
||||
|
||||
def compare_values(self, x, y):
|
||||
if x is None or y is None:
|
||||
return x is y
|
||||
return x == y
|
||||
|
@ -1,29 +1,36 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Self, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from core.variables import utils as variable_utils
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import build_segment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import AppMode
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
import contexts
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, Variable
|
||||
from core.variables import SecretVariable, Segment, SegmentType, Variable
|
||||
from factories import variable_factory
|
||||
from libs import helper
|
||||
|
||||
from .account import Account
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole
|
||||
from .types import StringUUID
|
||||
from .enums import CreatorUserRole, DraftVariableType
|
||||
from .types import EnumText, StringUUID
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import AppMode
|
||||
@ -651,7 +658,7 @@ class WorkflowNodeExecution(Base):
|
||||
return json.loads(self.inputs) if self.inputs else None
|
||||
|
||||
@property
|
||||
def outputs_dict(self):
|
||||
def outputs_dict(self) -> dict[str, Any] | None:
|
||||
return json.loads(self.outputs) if self.outputs else None
|
||||
|
||||
@property
|
||||
@ -659,7 +666,7 @@ class WorkflowNodeExecution(Base):
|
||||
return json.loads(self.process_data) if self.process_data else None
|
||||
|
||||
@property
|
||||
def execution_metadata_dict(self):
|
||||
def execution_metadata_dict(self) -> dict[str, Any] | None:
|
||||
return json.loads(self.execution_metadata) if self.execution_metadata else None
|
||||
|
||||
@property
|
||||
@ -797,3 +804,202 @@ class ConversationVariable(Base):
|
||||
def to_variable(self) -> Variable:
|
||||
mapping = json.loads(self.data)
|
||||
return variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
|
||||
|
||||
# Only `sys.query` and `sys.files` could be modified.
|
||||
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
|
||||
|
||||
|
||||
def _naive_utc_datetime():
|
||||
return datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
|
||||
class WorkflowDraftVariable(Base):
|
||||
@staticmethod
|
||||
def unique_columns() -> list[str]:
|
||||
return [
|
||||
"app_id",
|
||||
"node_id",
|
||||
"name",
|
||||
]
|
||||
|
||||
__tablename__ = "workflow_draft_variables"
|
||||
__table_args__ = (UniqueConstraint(*unique_columns()),)
|
||||
|
||||
# id is the unique identifier of a draft variable.
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
|
||||
created_at = mapped_column(
|
||||
db.DateTime,
|
||||
nullable=False,
|
||||
default=_naive_utc_datetime,
|
||||
server_default=func.current_timestamp(),
|
||||
)
|
||||
|
||||
updated_at = mapped_column(
|
||||
db.DateTime,
|
||||
nullable=False,
|
||||
default=_naive_utc_datetime,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
)
|
||||
|
||||
# "`app_id` maps to the `id` field in the `model.App` model."
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
# `last_edited_at` records when the value of a given draft variable
|
||||
# is edited.
|
||||
#
|
||||
# If it's not edited after creation, its value is `None`.
|
||||
last_edited_at: Mapped[datetime | None] = mapped_column(
|
||||
db.DateTime,
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
|
||||
# The `node_id` field is special.
|
||||
#
|
||||
# If the variable is a conversation variable or a system variable, then the value of `node_id`
|
||||
# is `conversation` or `sys`, respective.
|
||||
#
|
||||
# Otherwise, if the variable is a variable belonging to a specific node, the value of `_node_id` is
|
||||
# the identity of correspond node in graph definition. An example of node id is `"1745769620734"`.
|
||||
#
|
||||
# However, there's one caveat. The id of the first "Answer" node in chatflow is "answer". (Other
|
||||
# "Answer" node conform the rules above.)
|
||||
node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id")
|
||||
|
||||
# From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than
|
||||
# 80 chars.
|
||||
#
|
||||
# ref: api/core/workflow/entities/variable_pool.py:18
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(
|
||||
sa.String(255),
|
||||
default="",
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector")
|
||||
|
||||
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
|
||||
# JSON string
|
||||
value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value")
|
||||
|
||||
# visible
|
||||
visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||
editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
|
||||
def get_selector(self) -> list[str]:
|
||||
selector = json.loads(self.selector)
|
||||
if not isinstance(selector, list):
|
||||
_logger.error(
|
||||
"invalid selector loaded from database, type=%s, value=%s",
|
||||
type(selector),
|
||||
self.selector,
|
||||
)
|
||||
raise ValueError("invalid selector.")
|
||||
return selector
|
||||
|
||||
def _set_selector(self, value: list[str]):
|
||||
self.selector = json.dumps(value)
|
||||
|
||||
def get_value(self) -> Segment | None:
|
||||
return build_segment(json.loads(self.value))
|
||||
|
||||
def set_name(self, name: str):
|
||||
self.name = name
|
||||
self._set_selector([self.node_id, name])
|
||||
|
||||
def set_value(self, value: Segment):
|
||||
self.value = json.dumps(value.value)
|
||||
self.value_type = value.value_type
|
||||
|
||||
def get_node_id(self) -> str | None:
|
||||
if self.get_variable_type() == DraftVariableType.NODE:
|
||||
return self.node_id
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_variable_type(self) -> DraftVariableType:
|
||||
match self.node_id:
|
||||
case DraftVariableType.CONVERSATION:
|
||||
return DraftVariableType.CONVERSATION
|
||||
case DraftVariableType.SYS:
|
||||
return DraftVariableType.SYS
|
||||
case _:
|
||||
return DraftVariableType.NODE
|
||||
|
||||
@classmethod
|
||||
def _new(
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
description: str = "",
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = WorkflowDraftVariable()
|
||||
variable.created_at = _naive_utc_datetime()
|
||||
variable.updated_at = _naive_utc_datetime()
|
||||
variable.description = description
|
||||
variable.app_id = app_id
|
||||
variable.node_id = node_id
|
||||
variable.name = name
|
||||
variable.app_id = app_id
|
||||
variable.set_value(value)
|
||||
variable._set_selector(list(variable_utils.to_selector(node_id, name)))
|
||||
return variable
|
||||
|
||||
@classmethod
|
||||
def new_conversation_variable(
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
node_id=CONVERSATION_VARIABLE_NODE_ID,
|
||||
name=name,
|
||||
value=value,
|
||||
)
|
||||
return variable
|
||||
|
||||
@classmethod
|
||||
def new_sys_variable(
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
editable: bool = False,
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value)
|
||||
variable.editable = editable
|
||||
return variable
|
||||
|
||||
@classmethod
|
||||
def new_node_variable(
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
visible: bool = True,
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value)
|
||||
variable.visible = visible
|
||||
variable.editable = True
|
||||
return variable
|
||||
|
||||
@property
|
||||
def edited(self):
|
||||
return self.last_edited_at is not None
|
||||
|
||||
|
||||
def is_system_variable_editable(name: str) -> bool:
|
||||
return name in _EDITABLE_SYSTEM_VARIABLE
|
||||
|
187
api/tests/unit_tests/models/test_types_enum_text.py
Normal file
187
api/tests/unit_tests/models/test_types_enum_text.py
Normal file
@ -0,0 +1,187 @@
|
||||
from collections.abc import Callable, Iterable
|
||||
from enum import StrEnum
|
||||
from typing import Any, NamedTuple, TypeVar
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exc as sa_exc
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, Session
|
||||
from sqlalchemy.sql.sqltypes import VARCHAR
|
||||
|
||||
from models.types import EnumText
|
||||
|
||||
_user_type_admin = "admin"
|
||||
_user_type_normal = "normal"
|
||||
|
||||
|
||||
class _Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class _UserType(StrEnum):
|
||||
admin = _user_type_admin
|
||||
normal = _user_type_normal
|
||||
|
||||
|
||||
class _EnumWithLongValue(StrEnum):
|
||||
unknown = "unknown"
|
||||
a_really_long_enum_values = "a_really_long_enum_values"
|
||||
|
||||
|
||||
class _User(_Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
|
||||
name: Mapped[str] = sa.Column(sa.String(length=255), nullable=False)
|
||||
user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal)
|
||||
user_type_nullable: Mapped[_UserType | None] = sa.Column(EnumText(enum_class=_UserType), nullable=True)
|
||||
|
||||
|
||||
class _ColumnTest(_Base):
|
||||
__tablename__ = "column_test"
|
||||
|
||||
id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal)
|
||||
explicit_length: Mapped[_UserType | None] = sa.Column(
|
||||
EnumText(_UserType, length=50), nullable=True, default=_UserType.normal
|
||||
)
|
||||
long_value: Mapped[_EnumWithLongValue] = sa.Column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _first(it: Iterable[_T]) -> _T:
|
||||
ls = list(it)
|
||||
if not ls:
|
||||
raise ValueError("List is empty")
|
||||
return ls[0]
|
||||
|
||||
|
||||
class TestEnumText:
|
||||
def test_column_impl(self):
|
||||
engine = sa.create_engine("sqlite://", echo=False)
|
||||
_Base.metadata.create_all(engine)
|
||||
|
||||
inspector = sa.inspect(engine)
|
||||
columns = inspector.get_columns(_ColumnTest.__tablename__)
|
||||
|
||||
user_type_column = _first(c for c in columns if c["name"] == "user_type")
|
||||
sql_type = user_type_column["type"]
|
||||
assert isinstance(user_type_column["type"], VARCHAR)
|
||||
assert sql_type.length == 20
|
||||
assert user_type_column["nullable"] is False
|
||||
|
||||
explicit_length_column = _first(c for c in columns if c["name"] == "explicit_length")
|
||||
sql_type = explicit_length_column["type"]
|
||||
assert isinstance(sql_type, VARCHAR)
|
||||
assert sql_type.length == 50
|
||||
assert explicit_length_column["nullable"] is True
|
||||
|
||||
long_value_column = _first(c for c in columns if c["name"] == "long_value")
|
||||
sql_type = long_value_column["type"]
|
||||
assert isinstance(sql_type, VARCHAR)
|
||||
assert sql_type.length == len(_EnumWithLongValue.a_really_long_enum_values)
|
||||
|
||||
def test_insert_and_select(self):
|
||||
engine = sa.create_engine("sqlite://", echo=False)
|
||||
_Base.metadata.create_all(engine)
|
||||
|
||||
with Session(engine) as session:
|
||||
admin_user = _User(
|
||||
name="admin",
|
||||
user_type=_UserType.admin,
|
||||
user_type_nullable=None,
|
||||
)
|
||||
session.add(admin_user)
|
||||
session.flush()
|
||||
admin_user_id = admin_user.id
|
||||
|
||||
normal_user = _User(
|
||||
name="normal",
|
||||
user_type=_UserType.normal.value,
|
||||
user_type_nullable=_UserType.normal.value,
|
||||
)
|
||||
session.add(normal_user)
|
||||
session.flush()
|
||||
normal_user_id = normal_user.id
|
||||
session.commit()
|
||||
|
||||
with Session(engine) as session:
|
||||
user = session.query(_User).filter(_User.id == admin_user_id).first()
|
||||
assert user.user_type == _UserType.admin
|
||||
assert user.user_type_nullable is None
|
||||
|
||||
with Session(engine) as session:
|
||||
user = session.query(_User).filter(_User.id == normal_user_id).first()
|
||||
assert user.user_type == _UserType.normal
|
||||
assert user.user_type_nullable == _UserType.normal
|
||||
|
||||
def test_insert_invalid_values(self):
|
||||
def _session_insert_with_value(sess: Session, user_type: Any):
|
||||
user = _User(name="test_user", user_type=user_type)
|
||||
sess.add(user)
|
||||
sess.flush()
|
||||
|
||||
def _insert_with_user(sess: Session, user_type: Any):
|
||||
stmt = insert(_User).values(
|
||||
{
|
||||
"name": "test_user",
|
||||
"user_type": user_type,
|
||||
}
|
||||
)
|
||||
sess.execute(stmt)
|
||||
|
||||
class TestCase(NamedTuple):
|
||||
name: str
|
||||
action: Callable[[Session], None]
|
||||
exc_type: type[Exception]
|
||||
|
||||
engine = sa.create_engine("sqlite://", echo=False)
|
||||
_Base.metadata.create_all(engine)
|
||||
cases = [
|
||||
TestCase(
|
||||
name="session insert with invalid value",
|
||||
action=lambda s: _session_insert_with_value(s, "invalid"),
|
||||
exc_type=ValueError,
|
||||
),
|
||||
TestCase(
|
||||
name="session insert with invalid type",
|
||||
action=lambda s: _session_insert_with_value(s, 1),
|
||||
exc_type=TypeError,
|
||||
),
|
||||
TestCase(
|
||||
name="insert with invalid value",
|
||||
action=lambda s: _insert_with_user(s, "invalid"),
|
||||
exc_type=ValueError,
|
||||
),
|
||||
TestCase(
|
||||
name="insert with invalid type",
|
||||
action=lambda s: _insert_with_user(s, 1),
|
||||
exc_type=TypeError,
|
||||
),
|
||||
]
|
||||
for idx, c in enumerate(cases, 1):
|
||||
with pytest.raises(sa_exc.StatementError) as exc:
|
||||
with Session(engine) as session:
|
||||
c.action(session)
|
||||
|
||||
assert isinstance(exc.value.orig, c.exc_type), f"test case {idx} failed, name={c.name}"
|
||||
|
||||
def test_select_invalid_values(self):
|
||||
engine = sa.create_engine("sqlite://", echo=False)
|
||||
_Base.metadata.create_all(engine)
|
||||
|
||||
insertion_sql = """
|
||||
INSERT INTO users (id, name, user_type) VALUES
|
||||
(1, 'invalid_value', 'invalid');
|
||||
"""
|
||||
with Session(engine) as session:
|
||||
session.execute(sa.text(insertion_sql))
|
||||
session.commit()
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
with Session(engine) as session:
|
||||
_user = session.query(_User).filter(_User.id == 1).first()
|
Loading…
x
Reference in New Issue
Block a user