Merge branch 'main' into main

This commit is contained in:
Sebastián Ramírez
2023-10-29 11:55:17 +04:00
committed by GitHub
165 changed files with 3744 additions and 3153 deletions

View File

@@ -1,16 +1,16 @@
__version__ = "0.0.8"
__version__ = "0.0.10"
# Re-export from SQLAlchemy
from sqlalchemy.engine import create_mock_engine as create_mock_engine
from sqlalchemy.engine import engine_from_config as engine_from_config
from sqlalchemy.inspection import inspect as inspect
from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
from sqlalchemy.schema import DDL as DDL
from sqlalchemy.schema import CheckConstraint as CheckConstraint
from sqlalchemy.schema import Column as Column
from sqlalchemy.schema import ColumnDefault as ColumnDefault
from sqlalchemy.schema import Computed as Computed
from sqlalchemy.schema import Constraint as Constraint
from sqlalchemy.schema import DDL as DDL
from sqlalchemy.schema import DefaultClause as DefaultClause
from sqlalchemy.schema import FetchedValue as FetchedValue
from sqlalchemy.schema import ForeignKey as ForeignKey
@@ -23,6 +23,14 @@ from sqlalchemy.schema import Sequence as Sequence
from sqlalchemy.schema import Table as Table
from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
from sqlalchemy.sql import (
LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
)
from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
from sqlalchemy.sql import (
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
)
from sqlalchemy.sql import alias as alias
from sqlalchemy.sql import all_ as all_
from sqlalchemy.sql import and_ as and_
@@ -48,14 +56,6 @@ from sqlalchemy.sql import insert as insert
from sqlalchemy.sql import intersect as intersect
from sqlalchemy.sql import intersect_all as intersect_all
from sqlalchemy.sql import join as join
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
from sqlalchemy.sql import (
LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
)
from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
from sqlalchemy.sql import (
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
)
from sqlalchemy.sql import lambda_stmt as lambda_stmt
from sqlalchemy.sql import lateral as lateral
from sqlalchemy.sql import literal as literal
@@ -85,55 +85,53 @@ from sqlalchemy.sql import values as values
from sqlalchemy.sql import within_group as within_group
from sqlalchemy.types import ARRAY as ARRAY
from sqlalchemy.types import BIGINT as BIGINT
from sqlalchemy.types import BigInteger as BigInteger
from sqlalchemy.types import BINARY as BINARY
from sqlalchemy.types import BLOB as BLOB
from sqlalchemy.types import BOOLEAN as BOOLEAN
from sqlalchemy.types import Boolean as Boolean
from sqlalchemy.types import CHAR as CHAR
from sqlalchemy.types import CLOB as CLOB
from sqlalchemy.types import DATE as DATE
from sqlalchemy.types import Date as Date
from sqlalchemy.types import DATETIME as DATETIME
from sqlalchemy.types import DateTime as DateTime
from sqlalchemy.types import DECIMAL as DECIMAL
from sqlalchemy.types import Enum as Enum
from sqlalchemy.types import FLOAT as FLOAT
from sqlalchemy.types import Float as Float
from sqlalchemy.types import INT as INT
from sqlalchemy.types import INTEGER as INTEGER
from sqlalchemy.types import Integer as Integer
from sqlalchemy.types import Interval as Interval
from sqlalchemy.types import JSON as JSON
from sqlalchemy.types import LargeBinary as LargeBinary
from sqlalchemy.types import NCHAR as NCHAR
from sqlalchemy.types import NUMERIC as NUMERIC
from sqlalchemy.types import Numeric as Numeric
from sqlalchemy.types import NVARCHAR as NVARCHAR
from sqlalchemy.types import PickleType as PickleType
from sqlalchemy.types import REAL as REAL
from sqlalchemy.types import SMALLINT as SMALLINT
from sqlalchemy.types import TEXT as TEXT
from sqlalchemy.types import TIME as TIME
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
from sqlalchemy.types import VARBINARY as VARBINARY
from sqlalchemy.types import VARCHAR as VARCHAR
from sqlalchemy.types import BigInteger as BigInteger
from sqlalchemy.types import Boolean as Boolean
from sqlalchemy.types import Date as Date
from sqlalchemy.types import DateTime as DateTime
from sqlalchemy.types import Enum as Enum
from sqlalchemy.types import Float as Float
from sqlalchemy.types import Integer as Integer
from sqlalchemy.types import Interval as Interval
from sqlalchemy.types import LargeBinary as LargeBinary
from sqlalchemy.types import Numeric as Numeric
from sqlalchemy.types import PickleType as PickleType
from sqlalchemy.types import SmallInteger as SmallInteger
from sqlalchemy.types import String as String
from sqlalchemy.types import TEXT as TEXT
from sqlalchemy.types import Text as Text
from sqlalchemy.types import TIME as TIME
from sqlalchemy.types import Time as Time
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
from sqlalchemy.types import TypeDecorator as TypeDecorator
from sqlalchemy.types import Unicode as Unicode
from sqlalchemy.types import UnicodeText as UnicodeText
from sqlalchemy.types import VARBINARY as VARBINARY
from sqlalchemy.types import VARCHAR as VARCHAR
# Extensions and modifications of SQLAlchemy in SQLModel
# From SQLModel, modifications of SQLAlchemy or equivalents of Pydantic
from .engine.create import create_engine as create_engine
from .orm.session import Session as Session
from .sql.expression import select as select
from .sql.expression import col as col
from .sql.sqltypes import AutoString as AutoString
# Export SQLModel specifics (equivalent to Pydantic)
from .main import SQLModel as SQLModel
from .main import Field as Field
from .main import Relationship as Relationship
from .main import SQLModel as SQLModel
from .orm.session import Session as Session
from .sql.expression import col as col
from .sql.expression import select as select
from .sql.sqltypes import AutoString as AutoString

View File

@@ -6,7 +6,7 @@ class _DefaultPlaceholder:
You shouldn't use this class directly.
It's used internally to recognize when a default value has been overwritten, even
if the overriden default value was truthy.
if the overridden default value was truthy.
"""
def __init__(self, value: Any):
@@ -27,6 +27,6 @@ def Default(value: _TDefaultType) -> _TDefaultType:
You shouldn't use this function directly.
It's used internally to recognize when a default value has been overwritten, even
if the overriden default value was truthy.
if the overridden default value was truthy.
"""
return _DefaultPlaceholder(value) # type: ignore

View File

@@ -1,17 +1,17 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel.sql.base import Executable
from ...engine.result import ScalarResult
from ...engine.result import Result, ScalarResult
from ...orm.session import Session
from ...sql.expression import Select
from ...sql.base import Executable
from ...sql.expression import Select, SelectOfScalar
_T = TypeVar("_T")
_TSelectParam = TypeVar("_TSelectParam")
class AsyncSession(_AsyncSession):
@@ -40,14 +40,46 @@ class AsyncSession(_AsyncSession):
Session(bind=bind, binds=binds, **kw) # type: ignore
)
@overload
async def exec(
self,
statement: Union[Select[_T], Executable[_T]],
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[_TSelectParam]:
...
@overload
async def exec(
self,
statement: SelectOfScalar[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...
async def exec(
self,
statement: Union[
Select[_TSelectParam],
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> ScalarResult[_T]:
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
# TODO: the documentation says execution_options accepts a dict, but only
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore

View File

@@ -11,6 +11,7 @@ from typing import (
Callable,
ClassVar,
Dict,
ForwardRef,
List,
Mapping,
Optional,
@@ -21,19 +22,29 @@ from typing import (
TypeVar,
Union,
cast,
overload,
)
from pydantic import BaseConfig, BaseModel
from pydantic.errors import ConfigError, DictError
from pydantic.fields import SHAPE_SINGLETON
from pydantic.fields import SHAPE_SINGLETON, ModelField, Undefined, UndefinedType
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass, validate_model
from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations
from pydantic.typing import NoArgAnyCallable, resolve_annotations
from pydantic.utils import ROOT_KEY, Representation
from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import (
Boolean,
Column,
Date,
DateTime,
Float,
ForeignKey,
Integer,
Interval,
Numeric,
inspect,
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
@@ -78,6 +89,28 @@ class FieldInfo(PydanticFieldInfo):
"Passing sa_column_kwargs is not supported when "
"also passing a sa_column"
)
if primary_key is not Undefined:
raise RuntimeError(
"Passing primary_key is not supported when "
"also passing a sa_column"
)
if nullable is not Undefined:
raise RuntimeError(
"Passing nullable is not supported when " "also passing a sa_column"
)
if foreign_key is not Undefined:
raise RuntimeError(
"Passing foreign_key is not supported when "
"also passing a sa_column"
)
if unique is not Undefined:
raise RuntimeError(
"Passing unique is not supported when " "also passing a sa_column"
)
if index is not Undefined:
raise RuntimeError(
"Passing index is not supported when " "also passing a sa_column"
)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.nullable = nullable
@@ -118,6 +151,7 @@ class RelationshipInfo(Representation):
self.sa_relationship_kwargs = sa_relationship_kwargs
@overload
def Field(
default: Any = Undefined,
*,
@@ -137,15 +171,99 @@ def Field(
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
unique: bool = False,
discriminator: Optional[str] = None,
repr: bool = True,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: Any = Undefined,
unique: Union[bool, UndefinedType] = Undefined,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
...
@overload
def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
...
def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
max_digits: Optional[int] = None,
decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
repr: bool = True,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: Any = Undefined,
unique: Union[bool, UndefinedType] = Undefined,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_type: Type[Any] = Undefined,
@@ -169,12 +287,17 @@ def Field(
lt=lt,
le=le,
multiple_of=multiple_of,
max_digits=max_digits,
decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
discriminator=discriminator,
repr=repr,
primary_key=primary_key,
foreign_key=foreign_key,
unique=unique,
@@ -190,6 +313,27 @@ def Field(
return field_info
@overload
def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
...
@overload
def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
) -> Any:
...
def Relationship(
*,
back_populates: Optional[str] = None,
@@ -308,9 +452,9 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
config_registry = cast(registry, config_registry)
# If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.registry = config_table
setattr(new_cls, "_sa_registry", config_registry)
setattr(new_cls, "metadata", config_registry.metadata)
setattr(new_cls, "__abstract__", True)
setattr(new_cls, "_sa_registry", config_registry) # noqa: B010
setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010
setattr(new_cls, "__abstract__", True) # noqa: B010
return new_cls
# Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
@@ -323,19 +467,16 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# triggers an error
base_is_table = False
for base in bases:
config = getattr(base, "__config__")
config = getattr(base, "__config__") # noqa: B009
if config and getattr(config, "table", False):
base_is_table = True
break
if getattr(cls.__config__, "table", False) and not base_is_table:
dict_used = dict_.copy()
for field_name, field_value in cls.__fields__.items():
dict_used[field_name] = get_column_from_field(field_value)
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
# over anything else, use that and continue with the next attribute
dict_used[rel_name] = rel_info.sa_relationship
setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315
continue
ann = cls.__annotations__[rel_name]
temp_field = ModelField.infer(
@@ -353,7 +494,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
rel_kwargs["back_populates"] = rel_info.back_populates
if rel_info.link_model:
ins = inspect(rel_info.link_model)
local_table = getattr(ins, "local_table")
local_table = getattr(ins, "local_table") # noqa: B009
if local_table is None:
raise RuntimeError(
"Couldn't find the secondary table for "
@@ -368,9 +509,11 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
rel_value: RelationshipProperty = relationship( # type: ignore
relationship_to, *rel_args, **rel_kwargs
)
dict_used[rel_name] = rel_value
setattr(cls, rel_name, rel_value) # Fix #315
DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
# SQLAlchemy no longer uses dict_
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
# Tag: 1.4.36
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
else:
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
@@ -379,45 +522,47 @@ def get_sqlalchemy_type(field: ModelField) -> Any:
if hasattr(field.field_info, "sa_type"):
if not issubclass(type(field.field_info.sa_type), type(Undefined)):
return field.field_info.sa_type
if issubclass(field.type_, str):
if field.field_info.max_length:
return AutoString(length=field.field_info.max_length)
return AutoString
if issubclass(field.type_, float):
return Float
if issubclass(field.type_, bool):
return Boolean
if issubclass(field.type_, int):
return Integer
if issubclass(field.type_, datetime):
return DateTime
if issubclass(field.type_, date):
return Date
if issubclass(field.type_, timedelta):
return Interval
if issubclass(field.type_, time):
return Time
if issubclass(field.type_, Enum):
return sa_Enum(field.type_)
if issubclass(field.type_, bytes):
return LargeBinary
if issubclass(field.type_, Decimal):
return Numeric(
precision=getattr(field.type_, "max_digits", None),
scale=getattr(field.type_, "decimal_places", None),
)
if issubclass(field.type_, ipaddress.IPv4Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv4Network):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Network):
return AutoString
if issubclass(field.type_, Path):
return AutoString
if issubclass(field.type_, uuid.UUID):
return GUID
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
if issubclass(field.type_, Enum):
return sa_Enum(field.type_)
if issubclass(field.type_, str):
if field.field_info.max_length:
return AutoString(length=field.field_info.max_length)
return AutoString
if issubclass(field.type_, float):
return Float
if issubclass(field.type_, bool):
return Boolean
if issubclass(field.type_, int):
return Integer
if issubclass(field.type_, datetime):
return DateTime
if issubclass(field.type_, date):
return Date
if issubclass(field.type_, timedelta):
return Interval
if issubclass(field.type_, time):
return Time
if issubclass(field.type_, bytes):
return LargeBinary
if issubclass(field.type_, Decimal):
return Numeric(
precision=getattr(field.type_, "max_digits", None),
scale=getattr(field.type_, "decimal_places", None),
)
if issubclass(field.type_, ipaddress.IPv4Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv4Network):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Network):
return AutoString
if issubclass(field.type_, Path):
return AutoString
if issubclass(field.type_, uuid.UUID):
return GUID
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
@@ -426,21 +571,28 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
if isinstance(sa_column, Column):
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field.field_info, "primary_key", False)
primary_key = getattr(field.field_info, "primary_key", Undefined)
if primary_key is Undefined:
primary_key = False
index = getattr(field.field_info, "index", Undefined)
if index is Undefined:
index = False
nullable = not primary_key and _is_field_noneable(field)
# Override derived nullability if the nullable property is set explicitly
# on the field
if hasattr(field.field_info, "nullable"):
field_nullable = getattr(field.field_info, "nullable")
if field_nullable != Undefined:
nullable = field_nullable
field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009
if field_nullable != Undefined:
assert not isinstance(field_nullable, UndefinedType)
nullable = field_nullable
args = []
foreign_key = getattr(field.field_info, "foreign_key", None)
unique = getattr(field.field_info, "unique", False)
foreign_key = getattr(field.field_info, "foreign_key", Undefined)
if foreign_key is Undefined:
foreign_key = None
unique = getattr(field.field_info, "unique", Undefined)
if unique is Undefined:
unique = False
if foreign_key:
assert isinstance(foreign_key, str)
args.append(ForeignKey(foreign_key))
kwargs = {
"primary_key": primary_key,
@@ -583,7 +735,11 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
# Don't show SQLAlchemy private attributes
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
return [
(k, v)
for k, v in super().__repr_args__()
if not (isinstance(k, str) and k.startswith("_sa_"))
]
# From Pydantic, override to enforce validation with dict
@classmethod

View File

@@ -4,11 +4,11 @@ from sqlalchemy import util
from sqlalchemy.orm import Query as _Query
from sqlalchemy.orm import Session as _Session
from sqlalchemy.sql.base import Executable as _Executable
from sqlmodel.sql.expression import Select, SelectOfScalar
from typing_extensions import Literal
from ..engine.result import Result, ScalarResult
from ..sql.base import Executable
from ..sql.expression import Select, SelectOfScalar
_TSelectParam = TypeVar("_TSelectParam")

View File

@@ -1,6 +1,5 @@
# WARNING: do not modify this code, it is generated by expression.py.jinja2
import sys
from datetime import datetime
from typing import (
TYPE_CHECKING,
@@ -12,7 +11,6 @@ from typing import (
Type,
TypeVar,
Union,
cast,
overload,
)
from uuid import UUID
@@ -24,36 +22,17 @@ from sqlalchemy.sql.expression import Select as _Select
_TSelect = TypeVar("_TSelect")
# Workaround Generics incompatibility in Python 3.6
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
if sys.version_info.minor >= 7:
class Select(_Select, Generic[_TSelect]):
inherit_cache = True
class Select(_Select, Generic[_TSelect]):
inherit_cache = True
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
inherit_cache = True
else:
from typing import GenericMeta # type: ignore
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
pass
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
inherit_cache = True
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
inherit_cache = True
# Cast them for editors to work correctly, from several tricks tried, this works
# for both VS Code and PyCharm
Select = cast("Select", _Py36Select) # type: ignore
SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
inherit_cache = True
if TYPE_CHECKING: # pragma: no cover

View File

@@ -1,4 +1,3 @@
import sys
from datetime import datetime
from typing import (
TYPE_CHECKING,
@@ -10,7 +9,6 @@ from typing import (
Type,
TypeVar,
Union,
cast,
overload,
)
from uuid import UUID
@@ -22,37 +20,15 @@ from sqlalchemy.sql.expression import Select as _Select
_TSelect = TypeVar("_TSelect")
# Workaround Generics incompatibility in Python 3.6
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
if sys.version_info.minor >= 7:
class Select(_Select, Generic[_TSelect]):
inherit_cache = True
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
inherit_cache = True
else:
from typing import GenericMeta # type: ignore
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
pass
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
inherit_cache = True
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
inherit_cache = True
# Cast them for editors to work correctly, from several tricks tried, this works
# for both VS Code and PyCharm
Select = cast("Select", _Py36Select) # type: ignore
SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
class Select(_Select, Generic[_TSelect]):
inherit_cache = True
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
inherit_cache = True
if TYPE_CHECKING: # pragma: no cover
from ..main import SQLModel

View File

@@ -8,7 +8,6 @@ from sqlalchemy.sql.type_api import TypeEngine
class AutoString(types.TypeDecorator): # type: ignore
impl = types.String
cache_ok = True
mysql_default_length = 255