Merge branch 'main' into main
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
__version__ = "0.0.8"
|
||||
__version__ = "0.0.10"
|
||||
|
||||
# Re-export from SQLAlchemy
|
||||
from sqlalchemy.engine import create_mock_engine as create_mock_engine
|
||||
from sqlalchemy.engine import engine_from_config as engine_from_config
|
||||
from sqlalchemy.inspection import inspect as inspect
|
||||
from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
|
||||
from sqlalchemy.schema import DDL as DDL
|
||||
from sqlalchemy.schema import CheckConstraint as CheckConstraint
|
||||
from sqlalchemy.schema import Column as Column
|
||||
from sqlalchemy.schema import ColumnDefault as ColumnDefault
|
||||
from sqlalchemy.schema import Computed as Computed
|
||||
from sqlalchemy.schema import Constraint as Constraint
|
||||
from sqlalchemy.schema import DDL as DDL
|
||||
from sqlalchemy.schema import DefaultClause as DefaultClause
|
||||
from sqlalchemy.schema import FetchedValue as FetchedValue
|
||||
from sqlalchemy.schema import ForeignKey as ForeignKey
|
||||
@@ -23,6 +23,14 @@ from sqlalchemy.schema import Sequence as Sequence
|
||||
from sqlalchemy.schema import Table as Table
|
||||
from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
|
||||
from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
|
||||
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
|
||||
from sqlalchemy.sql import (
|
||||
LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
|
||||
)
|
||||
from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
|
||||
from sqlalchemy.sql import (
|
||||
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
|
||||
)
|
||||
from sqlalchemy.sql import alias as alias
|
||||
from sqlalchemy.sql import all_ as all_
|
||||
from sqlalchemy.sql import and_ as and_
|
||||
@@ -48,14 +56,6 @@ from sqlalchemy.sql import insert as insert
|
||||
from sqlalchemy.sql import intersect as intersect
|
||||
from sqlalchemy.sql import intersect_all as intersect_all
|
||||
from sqlalchemy.sql import join as join
|
||||
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
|
||||
from sqlalchemy.sql import (
|
||||
LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
|
||||
)
|
||||
from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
|
||||
from sqlalchemy.sql import (
|
||||
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
|
||||
)
|
||||
from sqlalchemy.sql import lambda_stmt as lambda_stmt
|
||||
from sqlalchemy.sql import lateral as lateral
|
||||
from sqlalchemy.sql import literal as literal
|
||||
@@ -85,55 +85,53 @@ from sqlalchemy.sql import values as values
|
||||
from sqlalchemy.sql import within_group as within_group
|
||||
from sqlalchemy.types import ARRAY as ARRAY
|
||||
from sqlalchemy.types import BIGINT as BIGINT
|
||||
from sqlalchemy.types import BigInteger as BigInteger
|
||||
from sqlalchemy.types import BINARY as BINARY
|
||||
from sqlalchemy.types import BLOB as BLOB
|
||||
from sqlalchemy.types import BOOLEAN as BOOLEAN
|
||||
from sqlalchemy.types import Boolean as Boolean
|
||||
from sqlalchemy.types import CHAR as CHAR
|
||||
from sqlalchemy.types import CLOB as CLOB
|
||||
from sqlalchemy.types import DATE as DATE
|
||||
from sqlalchemy.types import Date as Date
|
||||
from sqlalchemy.types import DATETIME as DATETIME
|
||||
from sqlalchemy.types import DateTime as DateTime
|
||||
from sqlalchemy.types import DECIMAL as DECIMAL
|
||||
from sqlalchemy.types import Enum as Enum
|
||||
from sqlalchemy.types import FLOAT as FLOAT
|
||||
from sqlalchemy.types import Float as Float
|
||||
from sqlalchemy.types import INT as INT
|
||||
from sqlalchemy.types import INTEGER as INTEGER
|
||||
from sqlalchemy.types import Integer as Integer
|
||||
from sqlalchemy.types import Interval as Interval
|
||||
from sqlalchemy.types import JSON as JSON
|
||||
from sqlalchemy.types import LargeBinary as LargeBinary
|
||||
from sqlalchemy.types import NCHAR as NCHAR
|
||||
from sqlalchemy.types import NUMERIC as NUMERIC
|
||||
from sqlalchemy.types import Numeric as Numeric
|
||||
from sqlalchemy.types import NVARCHAR as NVARCHAR
|
||||
from sqlalchemy.types import PickleType as PickleType
|
||||
from sqlalchemy.types import REAL as REAL
|
||||
from sqlalchemy.types import SMALLINT as SMALLINT
|
||||
from sqlalchemy.types import TEXT as TEXT
|
||||
from sqlalchemy.types import TIME as TIME
|
||||
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
|
||||
from sqlalchemy.types import VARBINARY as VARBINARY
|
||||
from sqlalchemy.types import VARCHAR as VARCHAR
|
||||
from sqlalchemy.types import BigInteger as BigInteger
|
||||
from sqlalchemy.types import Boolean as Boolean
|
||||
from sqlalchemy.types import Date as Date
|
||||
from sqlalchemy.types import DateTime as DateTime
|
||||
from sqlalchemy.types import Enum as Enum
|
||||
from sqlalchemy.types import Float as Float
|
||||
from sqlalchemy.types import Integer as Integer
|
||||
from sqlalchemy.types import Interval as Interval
|
||||
from sqlalchemy.types import LargeBinary as LargeBinary
|
||||
from sqlalchemy.types import Numeric as Numeric
|
||||
from sqlalchemy.types import PickleType as PickleType
|
||||
from sqlalchemy.types import SmallInteger as SmallInteger
|
||||
from sqlalchemy.types import String as String
|
||||
from sqlalchemy.types import TEXT as TEXT
|
||||
from sqlalchemy.types import Text as Text
|
||||
from sqlalchemy.types import TIME as TIME
|
||||
from sqlalchemy.types import Time as Time
|
||||
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
|
||||
from sqlalchemy.types import TypeDecorator as TypeDecorator
|
||||
from sqlalchemy.types import Unicode as Unicode
|
||||
from sqlalchemy.types import UnicodeText as UnicodeText
|
||||
from sqlalchemy.types import VARBINARY as VARBINARY
|
||||
from sqlalchemy.types import VARCHAR as VARCHAR
|
||||
|
||||
# Extensions and modifications of SQLAlchemy in SQLModel
|
||||
# From SQLModel, modifications of SQLAlchemy or equivalents of Pydantic
|
||||
from .engine.create import create_engine as create_engine
|
||||
from .orm.session import Session as Session
|
||||
from .sql.expression import select as select
|
||||
from .sql.expression import col as col
|
||||
from .sql.sqltypes import AutoString as AutoString
|
||||
|
||||
# Export SQLModel specifics (equivalent to Pydantic)
|
||||
from .main import SQLModel as SQLModel
|
||||
from .main import Field as Field
|
||||
from .main import Relationship as Relationship
|
||||
from .main import SQLModel as SQLModel
|
||||
from .orm.session import Session as Session
|
||||
from .sql.expression import col as col
|
||||
from .sql.expression import select as select
|
||||
from .sql.sqltypes import AutoString as AutoString
|
||||
|
||||
@@ -6,7 +6,7 @@ class _DefaultPlaceholder:
|
||||
You shouldn't use this class directly.
|
||||
|
||||
It's used internally to recognize when a default value has been overwritten, even
|
||||
if the overriden default value was truthy.
|
||||
if the overridden default value was truthy.
|
||||
"""
|
||||
|
||||
def __init__(self, value: Any):
|
||||
@@ -27,6 +27,6 @@ def Default(value: _TDefaultType) -> _TDefaultType:
|
||||
You shouldn't use this function directly.
|
||||
|
||||
It's used internally to recognize when a default value has been overwritten, even
|
||||
if the overriden default value was truthy.
|
||||
if the overridden default value was truthy.
|
||||
"""
|
||||
return _DefaultPlaceholder(value) # type: ignore
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
|
||||
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
|
||||
|
||||
from sqlalchemy import util
|
||||
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
|
||||
from sqlalchemy.ext.asyncio import engine
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
|
||||
from sqlalchemy.util.concurrency import greenlet_spawn
|
||||
from sqlmodel.sql.base import Executable
|
||||
|
||||
from ...engine.result import ScalarResult
|
||||
from ...engine.result import Result, ScalarResult
|
||||
from ...orm.session import Session
|
||||
from ...sql.expression import Select
|
||||
from ...sql.base import Executable
|
||||
from ...sql.expression import Select, SelectOfScalar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_TSelectParam = TypeVar("_TSelectParam")
|
||||
|
||||
|
||||
class AsyncSession(_AsyncSession):
|
||||
@@ -40,14 +40,46 @@ class AsyncSession(_AsyncSession):
|
||||
Session(bind=bind, binds=binds, **kw) # type: ignore
|
||||
)
|
||||
|
||||
@overload
|
||||
async def exec(
|
||||
self,
|
||||
statement: Union[Select[_T], Executable[_T]],
|
||||
statement: Select[_TSelectParam],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
**kw: Any,
|
||||
) -> Result[_TSelectParam]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def exec(
|
||||
self,
|
||||
statement: SelectOfScalar[_TSelectParam],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
**kw: Any,
|
||||
) -> ScalarResult[_TSelectParam]:
|
||||
...
|
||||
|
||||
async def exec(
|
||||
self,
|
||||
statement: Union[
|
||||
Select[_TSelectParam],
|
||||
SelectOfScalar[_TSelectParam],
|
||||
Executable[_TSelectParam],
|
||||
],
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
**kw: Any,
|
||||
) -> ScalarResult[_T]:
|
||||
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
|
||||
# TODO: the documentation says execution_options accepts a dict, but only
|
||||
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
|
||||
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
|
||||
|
||||
288
sqlmodel/main.py
288
sqlmodel/main.py
@@ -11,6 +11,7 @@ from typing import (
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
@@ -21,19 +22,29 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import BaseConfig, BaseModel
|
||||
from pydantic.errors import ConfigError, DictError
|
||||
from pydantic.fields import SHAPE_SINGLETON
|
||||
from pydantic.fields import SHAPE_SINGLETON, ModelField, Undefined, UndefinedType
|
||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||
from pydantic.fields import ModelField, Undefined, UndefinedType
|
||||
from pydantic.main import ModelMetaclass, validate_model
|
||||
from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations
|
||||
from pydantic.typing import NoArgAnyCallable, resolve_annotations
|
||||
from pydantic.utils import ROOT_KEY, Representation
|
||||
from sqlalchemy import Boolean, Column, Date, DateTime
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Interval,
|
||||
Numeric,
|
||||
inspect,
|
||||
)
|
||||
from sqlalchemy import Enum as sa_Enum
|
||||
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
|
||||
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
|
||||
from sqlalchemy.orm.attributes import set_attribute
|
||||
from sqlalchemy.orm.decl_api import DeclarativeMeta
|
||||
@@ -78,6 +89,28 @@ class FieldInfo(PydanticFieldInfo):
|
||||
"Passing sa_column_kwargs is not supported when "
|
||||
"also passing a sa_column"
|
||||
)
|
||||
if primary_key is not Undefined:
|
||||
raise RuntimeError(
|
||||
"Passing primary_key is not supported when "
|
||||
"also passing a sa_column"
|
||||
)
|
||||
if nullable is not Undefined:
|
||||
raise RuntimeError(
|
||||
"Passing nullable is not supported when " "also passing a sa_column"
|
||||
)
|
||||
if foreign_key is not Undefined:
|
||||
raise RuntimeError(
|
||||
"Passing foreign_key is not supported when "
|
||||
"also passing a sa_column"
|
||||
)
|
||||
if unique is not Undefined:
|
||||
raise RuntimeError(
|
||||
"Passing unique is not supported when " "also passing a sa_column"
|
||||
)
|
||||
if index is not Undefined:
|
||||
raise RuntimeError(
|
||||
"Passing index is not supported when " "also passing a sa_column"
|
||||
)
|
||||
super().__init__(default=default, **kwargs)
|
||||
self.primary_key = primary_key
|
||||
self.nullable = nullable
|
||||
@@ -118,6 +151,7 @@ class RelationshipInfo(Representation):
|
||||
self.sa_relationship_kwargs = sa_relationship_kwargs
|
||||
|
||||
|
||||
@overload
|
||||
def Field(
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
@@ -137,15 +171,99 @@ def Field(
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
primary_key: bool = False,
|
||||
foreign_key: Optional[Any] = None,
|
||||
unique: bool = False,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
primary_key: Union[bool, UndefinedType] = Undefined,
|
||||
foreign_key: Any = Undefined,
|
||||
unique: Union[bool, UndefinedType] = Undefined,
|
||||
nullable: Union[bool, UndefinedType] = Undefined,
|
||||
index: Union[bool, UndefinedType] = Undefined,
|
||||
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
|
||||
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
|
||||
schema_extra: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def Field(
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
include: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
|
||||
schema_extra: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
|
||||
def Field(
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
include: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
primary_key: Union[bool, UndefinedType] = Undefined,
|
||||
foreign_key: Any = Undefined,
|
||||
unique: Union[bool, UndefinedType] = Undefined,
|
||||
nullable: Union[bool, UndefinedType] = Undefined,
|
||||
index: Union[bool, UndefinedType] = Undefined,
|
||||
sa_type: Type[Any] = Undefined,
|
||||
@@ -169,12 +287,17 @@ def Field(
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
unique_items=unique_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
repr=repr,
|
||||
primary_key=primary_key,
|
||||
foreign_key=foreign_key,
|
||||
unique=unique,
|
||||
@@ -190,6 +313,27 @@ def Field(
|
||||
return field_info
|
||||
|
||||
|
||||
@overload
|
||||
def Relationship(
|
||||
*,
|
||||
back_populates: Optional[str] = None,
|
||||
link_model: Optional[Any] = None,
|
||||
sa_relationship_args: Optional[Sequence[Any]] = None,
|
||||
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def Relationship(
|
||||
*,
|
||||
back_populates: Optional[str] = None,
|
||||
link_model: Optional[Any] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
|
||||
def Relationship(
|
||||
*,
|
||||
back_populates: Optional[str] = None,
|
||||
@@ -308,9 +452,9 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
config_registry = cast(registry, config_registry)
|
||||
# If it was passed by kwargs, ensure it's also set in config
|
||||
new_cls.__config__.registry = config_table
|
||||
setattr(new_cls, "_sa_registry", config_registry)
|
||||
setattr(new_cls, "metadata", config_registry.metadata)
|
||||
setattr(new_cls, "__abstract__", True)
|
||||
setattr(new_cls, "_sa_registry", config_registry) # noqa: B010
|
||||
setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010
|
||||
setattr(new_cls, "__abstract__", True) # noqa: B010
|
||||
return new_cls
|
||||
|
||||
# Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
|
||||
@@ -323,19 +467,16 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
# triggers an error
|
||||
base_is_table = False
|
||||
for base in bases:
|
||||
config = getattr(base, "__config__")
|
||||
config = getattr(base, "__config__") # noqa: B009
|
||||
if config and getattr(config, "table", False):
|
||||
base_is_table = True
|
||||
break
|
||||
if getattr(cls.__config__, "table", False) and not base_is_table:
|
||||
dict_used = dict_.copy()
|
||||
for field_name, field_value in cls.__fields__.items():
|
||||
dict_used[field_name] = get_column_from_field(field_value)
|
||||
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
|
||||
if rel_info.sa_relationship:
|
||||
# There's a SQLAlchemy relationship declared, that takes precedence
|
||||
# over anything else, use that and continue with the next attribute
|
||||
dict_used[rel_name] = rel_info.sa_relationship
|
||||
setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315
|
||||
continue
|
||||
ann = cls.__annotations__[rel_name]
|
||||
temp_field = ModelField.infer(
|
||||
@@ -353,7 +494,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
rel_kwargs["back_populates"] = rel_info.back_populates
|
||||
if rel_info.link_model:
|
||||
ins = inspect(rel_info.link_model)
|
||||
local_table = getattr(ins, "local_table")
|
||||
local_table = getattr(ins, "local_table") # noqa: B009
|
||||
if local_table is None:
|
||||
raise RuntimeError(
|
||||
"Couldn't find the secondary table for "
|
||||
@@ -368,9 +509,11 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
rel_value: RelationshipProperty = relationship( # type: ignore
|
||||
relationship_to, *rel_args, **rel_kwargs
|
||||
)
|
||||
dict_used[rel_name] = rel_value
|
||||
setattr(cls, rel_name, rel_value) # Fix #315
|
||||
DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
|
||||
# SQLAlchemy no longer uses dict_
|
||||
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
|
||||
# Tag: 1.4.36
|
||||
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
||||
else:
|
||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
|
||||
@@ -379,45 +522,47 @@ def get_sqlalchemy_type(field: ModelField) -> Any:
|
||||
if hasattr(field.field_info, "sa_type"):
|
||||
if not issubclass(type(field.field_info.sa_type), type(Undefined)):
|
||||
return field.field_info.sa_type
|
||||
if issubclass(field.type_, str):
|
||||
if field.field_info.max_length:
|
||||
return AutoString(length=field.field_info.max_length)
|
||||
return AutoString
|
||||
if issubclass(field.type_, float):
|
||||
return Float
|
||||
if issubclass(field.type_, bool):
|
||||
return Boolean
|
||||
if issubclass(field.type_, int):
|
||||
return Integer
|
||||
if issubclass(field.type_, datetime):
|
||||
return DateTime
|
||||
if issubclass(field.type_, date):
|
||||
return Date
|
||||
if issubclass(field.type_, timedelta):
|
||||
return Interval
|
||||
if issubclass(field.type_, time):
|
||||
return Time
|
||||
if issubclass(field.type_, Enum):
|
||||
return sa_Enum(field.type_)
|
||||
if issubclass(field.type_, bytes):
|
||||
return LargeBinary
|
||||
if issubclass(field.type_, Decimal):
|
||||
return Numeric(
|
||||
precision=getattr(field.type_, "max_digits", None),
|
||||
scale=getattr(field.type_, "decimal_places", None),
|
||||
)
|
||||
if issubclass(field.type_, ipaddress.IPv4Address):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv4Network):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv6Address):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv6Network):
|
||||
return AutoString
|
||||
if issubclass(field.type_, Path):
|
||||
return AutoString
|
||||
if issubclass(field.type_, uuid.UUID):
|
||||
return GUID
|
||||
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
|
||||
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
|
||||
if issubclass(field.type_, Enum):
|
||||
return sa_Enum(field.type_)
|
||||
if issubclass(field.type_, str):
|
||||
if field.field_info.max_length:
|
||||
return AutoString(length=field.field_info.max_length)
|
||||
return AutoString
|
||||
if issubclass(field.type_, float):
|
||||
return Float
|
||||
if issubclass(field.type_, bool):
|
||||
return Boolean
|
||||
if issubclass(field.type_, int):
|
||||
return Integer
|
||||
if issubclass(field.type_, datetime):
|
||||
return DateTime
|
||||
if issubclass(field.type_, date):
|
||||
return Date
|
||||
if issubclass(field.type_, timedelta):
|
||||
return Interval
|
||||
if issubclass(field.type_, time):
|
||||
return Time
|
||||
if issubclass(field.type_, bytes):
|
||||
return LargeBinary
|
||||
if issubclass(field.type_, Decimal):
|
||||
return Numeric(
|
||||
precision=getattr(field.type_, "max_digits", None),
|
||||
scale=getattr(field.type_, "decimal_places", None),
|
||||
)
|
||||
if issubclass(field.type_, ipaddress.IPv4Address):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv4Network):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv6Address):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv6Network):
|
||||
return AutoString
|
||||
if issubclass(field.type_, Path):
|
||||
return AutoString
|
||||
if issubclass(field.type_, uuid.UUID):
|
||||
return GUID
|
||||
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
|
||||
|
||||
|
||||
@@ -426,21 +571,28 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
|
||||
if isinstance(sa_column, Column):
|
||||
return sa_column
|
||||
sa_type = get_sqlalchemy_type(field)
|
||||
primary_key = getattr(field.field_info, "primary_key", False)
|
||||
primary_key = getattr(field.field_info, "primary_key", Undefined)
|
||||
if primary_key is Undefined:
|
||||
primary_key = False
|
||||
index = getattr(field.field_info, "index", Undefined)
|
||||
if index is Undefined:
|
||||
index = False
|
||||
nullable = not primary_key and _is_field_noneable(field)
|
||||
# Override derived nullability if the nullable property is set explicitly
|
||||
# on the field
|
||||
if hasattr(field.field_info, "nullable"):
|
||||
field_nullable = getattr(field.field_info, "nullable")
|
||||
if field_nullable != Undefined:
|
||||
nullable = field_nullable
|
||||
field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009
|
||||
if field_nullable != Undefined:
|
||||
assert not isinstance(field_nullable, UndefinedType)
|
||||
nullable = field_nullable
|
||||
args = []
|
||||
foreign_key = getattr(field.field_info, "foreign_key", None)
|
||||
unique = getattr(field.field_info, "unique", False)
|
||||
foreign_key = getattr(field.field_info, "foreign_key", Undefined)
|
||||
if foreign_key is Undefined:
|
||||
foreign_key = None
|
||||
unique = getattr(field.field_info, "unique", Undefined)
|
||||
if unique is Undefined:
|
||||
unique = False
|
||||
if foreign_key:
|
||||
assert isinstance(foreign_key, str)
|
||||
args.append(ForeignKey(foreign_key))
|
||||
kwargs = {
|
||||
"primary_key": primary_key,
|
||||
@@ -583,7 +735,11 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
|
||||
def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
|
||||
# Don't show SQLAlchemy private attributes
|
||||
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if not (isinstance(k, str) and k.startswith("_sa_"))
|
||||
]
|
||||
|
||||
# From Pydantic, override to enforce validation with dict
|
||||
@classmethod
|
||||
|
||||
@@ -4,11 +4,11 @@ from sqlalchemy import util
|
||||
from sqlalchemy.orm import Query as _Query
|
||||
from sqlalchemy.orm import Session as _Session
|
||||
from sqlalchemy.sql.base import Executable as _Executable
|
||||
from sqlmodel.sql.expression import Select, SelectOfScalar
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..engine.result import Result, ScalarResult
|
||||
from ..sql.base import Executable
|
||||
from ..sql.expression import Select, SelectOfScalar
|
||||
|
||||
_TSelectParam = TypeVar("_TSelectParam")
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# WARNING: do not modify this code, it is generated by expression.py.jinja2
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -12,7 +11,6 @@ from typing import (
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from uuid import UUID
|
||||
@@ -24,36 +22,17 @@ from sqlalchemy.sql.expression import Select as _Select
|
||||
|
||||
_TSelect = TypeVar("_TSelect")
|
||||
|
||||
# Workaround Generics incompatibility in Python 3.6
|
||||
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
|
||||
if sys.version_info.minor >= 7:
|
||||
|
||||
class Select(_Select, Generic[_TSelect]):
|
||||
inherit_cache = True
|
||||
class Select(_Select, Generic[_TSelect]):
|
||||
inherit_cache = True
|
||||
|
||||
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
|
||||
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
|
||||
# entity, so the result will be converted to a scalar by default. This way writing
|
||||
# for loops on the results will feel natural.
|
||||
class SelectOfScalar(_Select, Generic[_TSelect]):
|
||||
inherit_cache = True
|
||||
|
||||
else:
|
||||
from typing import GenericMeta # type: ignore
|
||||
|
||||
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
|
||||
pass
|
||||
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
inherit_cache = True
|
||||
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
inherit_cache = True
|
||||
|
||||
# Cast them for editors to work correctly, from several tricks tried, this works
|
||||
# for both VS Code and PyCharm
|
||||
Select = cast("Select", _Py36Select) # type: ignore
|
||||
SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
|
||||
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
|
||||
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
|
||||
# entity, so the result will be converted to a scalar by default. This way writing
|
||||
# for loops on the results will feel natural.
|
||||
class SelectOfScalar(_Select, Generic[_TSelect]):
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -10,7 +9,6 @@ from typing import (
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from uuid import UUID
|
||||
@@ -22,37 +20,15 @@ from sqlalchemy.sql.expression import Select as _Select
|
||||
|
||||
_TSelect = TypeVar("_TSelect")
|
||||
|
||||
# Workaround Generics incompatibility in Python 3.6
|
||||
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
|
||||
if sys.version_info.minor >= 7:
|
||||
|
||||
class Select(_Select, Generic[_TSelect]):
|
||||
inherit_cache = True
|
||||
|
||||
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
|
||||
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
|
||||
# entity, so the result will be converted to a scalar by default. This way writing
|
||||
# for loops on the results will feel natural.
|
||||
class SelectOfScalar(_Select, Generic[_TSelect]):
|
||||
inherit_cache = True
|
||||
|
||||
else:
|
||||
from typing import GenericMeta # type: ignore
|
||||
|
||||
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
|
||||
pass
|
||||
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
inherit_cache = True
|
||||
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
inherit_cache = True
|
||||
|
||||
# Cast them for editors to work correctly, from several tricks tried, this works
|
||||
# for both VS Code and PyCharm
|
||||
Select = cast("Select", _Py36Select) # type: ignore
|
||||
SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
|
||||
class Select(_Select, Generic[_TSelect]):
|
||||
inherit_cache = True
|
||||
|
||||
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
|
||||
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
|
||||
# entity, so the result will be converted to a scalar by default. This way writing
|
||||
# for loops on the results will feel natural.
|
||||
class SelectOfScalar(_Select, Generic[_TSelect]):
|
||||
inherit_cache = True
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ..main import SQLModel
|
||||
|
||||
@@ -8,7 +8,6 @@ from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
|
||||
class AutoString(types.TypeDecorator): # type: ignore
|
||||
|
||||
impl = types.String
|
||||
cache_ok = True
|
||||
mysql_default_length = 255
|
||||
|
||||
Reference in New Issue
Block a user