2024-10-07 21:21:59 +00:00

1011 lines
36 KiB
Python

import ipaddress
import uuid
import weakref
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
ClassVar,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
from pydantic import BaseModel, EmailStr
from pydantic.fields import FieldInfo as PydanticFieldInfo
from sqlalchemy import (
Boolean,
Column,
Date,
DateTime,
Float,
ForeignKey,
Integer,
Interval,
Numeric,
inspect,
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import (
Mapped,
RelationshipProperty,
declared_attr,
registry,
relationship,
)
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import Literal, TypeAlias, deprecated, get_origin
from ._compat import ( # type: ignore[attr-defined]
IS_PYDANTIC_V2,
PYDANTIC_VERSION,
BaseConfig,
ModelField,
ModelMetaclass,
Representation,
SQLModelConfig,
Undefined,
UndefinedType,
_calculate_keys,
finish_init,
get_annotations,
get_config_value,
get_field_metadata,
get_model_fields,
get_relationship_to,
get_sa_type_from_field,
init_pydantic_private_attrs,
is_field_noneable,
is_table_model_class,
post_init_field_info,
set_config_value,
sqlmodel_init,
sqlmodel_validate,
)
from .sql.sqltypes import AutoString
if TYPE_CHECKING:
from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
from pydantic._internal._repr import Representation as Representation
from pydantic_core import PydanticUndefined as Undefined
from pydantic_core import PydanticUndefinedType as UndefinedType
_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
IncEx: TypeAlias = Union[
Set[int],
Set[str],
Mapping[int, Union["IncEx", Literal[True]]],
Mapping[str, Union["IncEx", Literal[True]]],
]
OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"]
def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (),
) -> Callable[[_T], _T]:
    """Return a decorator that leaves its target untouched.

    Presumably this exists so static type checkers can read the keyword
    arguments (dataclass-transform convention); at runtime every argument is
    ignored and the decorated object is handed back unchanged.
    """

    def _passthrough(target: _T) -> _T:
        return target

    return _passthrough
class FieldInfo(PydanticFieldInfo):
    """Pydantic ``FieldInfo`` extended with SQLAlchemy column options.

    The SQLModel-specific keyword arguments are popped out of ``kwargs`` and
    stored as attributes; everything else is forwarded to Pydantic.

    Raises:
        RuntimeError: if ``sa_column`` is combined with any column option it
            would make irrelevant, or if ``ondelete`` is given without
            ``foreign_key``.
    """

    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
        # ``primary_key`` and ``unique`` fall back to ``False`` when absent,
        # so whether they were actually supplied is recorded separately.
        # Comparing the popped value against ``Undefined`` would treat the
        # ``False`` fallback as "explicitly passed" and reject every direct
        # ``FieldInfo(sa_column=...)``.
        primary_key_given = kwargs.get("primary_key", Undefined) is not Undefined
        unique_given = kwargs.get("unique", Undefined) is not Undefined

        primary_key = kwargs.pop("primary_key", False)
        nullable = kwargs.pop("nullable", Undefined)
        foreign_key = kwargs.pop("foreign_key", Undefined)
        ondelete = kwargs.pop("ondelete", Undefined)
        unique = kwargs.pop("unique", False)
        index = kwargs.pop("index", Undefined)
        sa_type = kwargs.pop("sa_type", Undefined)
        sa_column = kwargs.pop("sa_column", Undefined)
        sa_column_args = kwargs.pop("sa_column_args", Undefined)
        sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)

        if sa_column is not Undefined:
            # A complete Column is used as-is when the table is built, so any
            # of these options passed alongside it would have no effect.
            # Checked in a fixed order; the first offender is reported.
            conflicts = (
                ("sa_column_args", sa_column_args is not Undefined),
                ("sa_column_kwargs", sa_column_kwargs is not Undefined),
                ("primary_key", primary_key_given),
                ("nullable", nullable is not Undefined),
                ("foreign_key", foreign_key is not Undefined),
                ("ondelete", ondelete is not Undefined),
                ("unique", unique_given),
                ("index", index is not Undefined),
                ("sa_type", sa_type is not Undefined),
            )
            for option, supplied in conflicts:
                if supplied:
                    raise RuntimeError(
                        f"Passing {option} is not supported when "
                        "also passing a sa_column"
                    )
        if ondelete is not Undefined and foreign_key is Undefined:
            raise RuntimeError("ondelete can only be used with foreign_key")

        super().__init__(default=default, **kwargs)
        self.primary_key = primary_key
        self.nullable = nullable
        self.foreign_key = foreign_key
        self.ondelete = ondelete
        self.unique = unique
        self.index = index
        self.sa_type = sa_type
        self.sa_column = sa_column
        self.sa_column_args = sa_column_args
        self.sa_column_kwargs = sa_column_kwargs
class RelationshipInfo(Representation):
    """Options captured by ``Relationship()`` for the metaclass to consume.

    Raises:
        RuntimeError: if a prebuilt ``sa_relationship`` is given together
            with ``sa_relationship_args`` or ``sa_relationship_kwargs``.
    """

    def __init__(
        self,
        *,
        back_populates: Optional[str] = None,
        cascade_delete: Optional[bool] = False,
        passive_deletes: Optional[Union[bool, Literal["all"]]] = False,
        link_model: Optional[Any] = None,
        sa_relationship: Optional[RelationshipProperty] = None,  # type: ignore
        sa_relationship_args: Optional[Sequence[Any]] = None,
        sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
    ) -> None:
        if sa_relationship is not None:
            # A prebuilt relationship cannot also take construction arguments.
            for option, supplied in (
                ("sa_relationship_args", sa_relationship_args),
                ("sa_relationship_kwargs", sa_relationship_kwargs),
            ):
                if supplied is not None:
                    raise RuntimeError(
                        f"Passing {option} is not supported when "
                        "also passing a sa_relationship"
                    )
        # Assignment order is kept as-is: it determines the repr field order.
        self.back_populates = back_populates
        self.cascade_delete = cascade_delete
        self.passive_deletes = passive_deletes
        self.link_model = link_model
        self.sa_relationship = sa_relationship
        self.sa_relationship_args = sa_relationship_args
        self.sa_relationship_kwargs = sa_relationship_kwargs
# Overload 1 (no ``sa_column``): column options such as ``primary_key``,
# ``foreign_key`` (any type), ``unique``, ``nullable``, ``index``,
# ``sa_type``, ``sa_column_args`` and ``sa_column_kwargs`` are accepted;
# ``ondelete`` is not part of this signature.
@overload
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    max_digits: Optional[int] = None,
    decimal_places: Optional[int] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: Optional[bool] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    discriminator: Optional[str] = None,
    repr: bool = True,
    primary_key: Union[bool, UndefinedType] = Undefined,
    foreign_key: Any = Undefined,
    unique: Union[bool, UndefinedType] = Undefined,
    nullable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    sa_type: Union[Type[Any], UndefinedType] = Undefined,
    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any: ...


# Overload 2: ``foreign_key`` given as a ``str`` (required here), which
# additionally allows ``ondelete``; the other column options match overload 1.
@overload
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    max_digits: Optional[int] = None,
    decimal_places: Optional[int] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: Optional[bool] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    discriminator: Optional[str] = None,
    repr: bool = True,
    primary_key: Union[bool, UndefinedType] = Undefined,
    foreign_key: str,
    ondelete: Union[OnDeleteType, UndefinedType] = Undefined,
    unique: Union[bool, UndefinedType] = Undefined,
    nullable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    sa_type: Union[Type[Any], UndefinedType] = Undefined,
    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any: ...


# Overload 3: a ready-made ``sa_column``. The options it would conflict with
# (primary_key, foreign_key, ondelete, unique, nullable, index, sa_type,
# sa_column_args, sa_column_kwargs) are deliberately left out of this
# signature; FieldInfo also rejects them at runtime when sa_column is set.
@overload
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    max_digits: Optional[int] = None,
    decimal_places: Optional[int] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: Optional[bool] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    discriminator: Optional[str] = None,
    repr: bool = True,
    sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any: ...
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    max_digits: Optional[int] = None,
    decimal_places: Optional[int] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: Optional[bool] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    discriminator: Optional[str] = None,
    repr: bool = True,
    primary_key: Union[bool, UndefinedType] = Undefined,
    foreign_key: Any = Undefined,
    ondelete: Union[OnDeleteType, UndefinedType] = Undefined,
    unique: Union[bool, UndefinedType] = Undefined,
    nullable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    sa_type: Union[Type[Any], UndefinedType] = Undefined,
    sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
    """Declare a model attribute with validation and column options.

    Validation/schema options and SQLModel column options are all forwarded
    to :class:`FieldInfo`, and the entries of ``schema_extra`` are passed
    along as extra keyword arguments. The ``FieldInfo`` is passed through
    ``post_init_field_info`` before being returned.
    """
    validation_options: Dict[str, Any] = dict(
        default_factory=default_factory,
        alias=alias,
        title=title,
        description=description,
        exclude=exclude,
        include=include,
        const=const,
        gt=gt,
        ge=ge,
        lt=lt,
        le=le,
        multiple_of=multiple_of,
        max_digits=max_digits,
        decimal_places=decimal_places,
        min_items=min_items,
        max_items=max_items,
        unique_items=unique_items,
        min_length=min_length,
        max_length=max_length,
        allow_mutation=allow_mutation,
        regex=regex,
        discriminator=discriminator,
        repr=repr,
    )
    column_options: Dict[str, Any] = dict(
        primary_key=primary_key,
        foreign_key=foreign_key,
        ondelete=ondelete,
        unique=unique,
        nullable=nullable,
        index=index,
        sa_type=sa_type,
        sa_column=sa_column,
        sa_column_args=sa_column_args,
        sa_column_kwargs=sa_column_kwargs,
    )
    extra_options = schema_extra or {}
    # A key repeated across these unpackings raises TypeError, just like a
    # schema_extra key that collides with an explicit keyword would.
    info = FieldInfo(
        default,
        **validation_options,
        **column_options,
        **extra_options,
    )
    post_init_field_info(info)
    return info
# Overload 1: build the relationship from options, optionally passing extra
# positional/keyword arguments on to SQLAlchemy's ``relationship()``.
@overload
def Relationship(
    *,
    back_populates: Optional[str] = None,
    cascade_delete: Optional[bool] = False,
    passive_deletes: Optional[Union[bool, Literal["all"]]] = False,
    link_model: Optional[Any] = None,
    sa_relationship_args: Optional[Sequence[Any]] = None,
    sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any: ...


# Overload 2: supply a prebuilt SQLAlchemy relationship via
# ``sa_relationship``; ``sa_relationship_args``/``sa_relationship_kwargs``
# are not part of this signature (RelationshipInfo rejects the combination).
@overload
def Relationship(
    *,
    back_populates: Optional[str] = None,
    cascade_delete: Optional[bool] = False,
    passive_deletes: Optional[Union[bool, Literal["all"]]] = False,
    link_model: Optional[Any] = None,
    sa_relationship: Optional[RelationshipProperty[Any]] = None,
) -> Any: ...
def Relationship(
    *,
    back_populates: Optional[str] = None,
    cascade_delete: Optional[bool] = False,
    passive_deletes: Optional[Union[bool, Literal["all"]]] = False,
    link_model: Optional[Any] = None,
    sa_relationship: Optional[RelationshipProperty[Any]] = None,
    sa_relationship_args: Optional[Sequence[Any]] = None,
    sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Declare a relationship attribute on a SQLModel class.

    Every option is stored unchanged on a :class:`RelationshipInfo`, which
    the metaclass later converts into a SQLAlchemy ``relationship()``.
    Combining ``sa_relationship`` with ``sa_relationship_args`` or
    ``sa_relationship_kwargs`` raises ``RuntimeError`` from
    ``RelationshipInfo``.
    """
    options: Dict[str, Any] = {
        "back_populates": back_populates,
        "cascade_delete": cascade_delete,
        "passive_deletes": passive_deletes,
        "link_model": link_model,
        "sa_relationship": sa_relationship,
        "sa_relationship_args": sa_relationship_args,
        "sa_relationship_kwargs": sa_relationship_kwargs,
    }
    return RelationshipInfo(**options)
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
    """Metaclass combining Pydantic's ``ModelMetaclass`` with ``DeclarativeMeta``.

    ``Relationship()`` attributes are held back from Pydantic and collected
    in ``__sqlmodel_relationships__``. For a class configured with
    ``table=True``, every model field is also set as a SQLAlchemy ``Column``.
    When no base class is itself a table model, the relationships become
    SQLAlchemy ``relationship()`` properties and the class is mapped by
    ``DeclarativeMeta``. All other classes are initialised by Pydantic's
    metaclass only.
    """

    # Attributes present on classes created by this metaclass (declared for
    # type checkers).
    __sqlmodel_relationships__: Dict[str, RelationshipInfo]
    model_config: SQLModelConfig
    model_fields: Dict[str, FieldInfo]
    __config__: Type[SQLModelConfig]
    __fields__: Dict[str, ModelField]  # type: ignore[assignment]

    # Replicate SQLAlchemy
    def __setattr__(cls, name: str, value: Any) -> None:
        # Table models route class-attribute assignment through SQLAlchemy's
        # declarative metaclass; non-table models use the normal MRO path.
        if is_table_model_class(cls):
            DeclarativeMeta.__setattr__(cls, name, value)
        else:
            super().__setattr__(name, value)

    def __delattr__(cls, name: str) -> None:
        # Same routing as __setattr__, for attribute deletion.
        if is_table_model_class(cls):
            DeclarativeMeta.__delattr__(cls, name)
        else:
            super().__delattr__(name)

    # From Pydantic
    def __new__(
        cls,
        name: str,
        bases: Tuple[Type[Any], ...],
        class_dict: Dict[str, Any],
        **kwargs: Any,
    ) -> Any:
        """Create the model class, keeping relationships away from Pydantic.

        ``kwargs`` are the class keyword arguments (e.g. ``table=True`` or
        ``registry=...``); only names that are attributes of ``BaseConfig``
        are forwarded to Pydantic's metaclass.
        """
        relationships: Dict[str, RelationshipInfo] = {}
        dict_for_pydantic = {}
        original_annotations = get_annotations(class_dict)
        pydantic_annotations = {}
        relationship_annotations = {}
        # Split the namespace: RelationshipInfo values are removed from what
        # Pydantic sees, everything else is passed through.
        for k, v in class_dict.items():
            if isinstance(v, RelationshipInfo):
                relationships[k] = v
            else:
                dict_for_pydantic[k] = v
        # Same split for annotations, so Pydantic doesn't treat relationship
        # attributes as fields.
        for k, v in original_annotations.items():
            if k in relationships:
                relationship_annotations[k] = v
            else:
                pydantic_annotations[k] = v
        dict_used = {
            **dict_for_pydantic,
            "__weakref__": None,
            "__sqlmodel_relationships__": relationships,
            "__annotations__": pydantic_annotations,
        }
        # Duplicate logic from Pydantic to filter config kwargs because if they are
        # passed directly including the registry Pydantic will pass them over to the
        # superclass causing an error
        allowed_config_kwargs: Set[str] = {
            key
            for key in dir(BaseConfig)
            if not (
                key.startswith("__") and key.endswith("__")
            )  # skip dunder methods and attributes
        }
        config_kwargs = {
            key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
        }
        new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
        # Restore the relationship annotations on the created class, merged
        # with whatever annotations the class ended up with.
        new_cls.__annotations__ = {
            **relationship_annotations,
            **pydantic_annotations,
            **new_cls.__annotations__,
        }

        def get_config(name: str) -> Any:
            # Look up a setting in the model config first, then in the class
            # keyword arguments; Undefined when neither provides it.
            config_class_value = get_config_value(
                model=new_cls, parameter=name, default=Undefined
            )
            if config_class_value is not Undefined:
                return config_class_value
            kwarg_value = kwargs.get(name, Undefined)
            if kwarg_value is not Undefined:
                return kwarg_value
            return Undefined

        config_table = get_config("table")
        if config_table is True:
            # If it was passed by kwargs, ensure it's also set in config
            set_config_value(model=new_cls, parameter="table", value=config_table)
            # Replace each model field on the class with its SQLAlchemy Column.
            for k, v in get_model_fields(new_cls).items():
                col = get_column_from_field(v)
                setattr(new_cls, k, col)
            # Set a config flag to tell FastAPI that this should be read with a field
            # in orm_mode instead of preemptively converting it to a dict.
            # This could be done by reading new_cls.model_config['table'] in FastAPI, but
            # that's very specific about SQLModel, so let's have another config that
            # other future tools based on Pydantic can use.
            set_config_value(
                model=new_cls, parameter="read_from_attributes", value=True
            )
            # For compatibility with older versions
            # TODO: remove this in the future
            set_config_value(model=new_cls, parameter="read_with_orm_mode", value=True)

        config_registry = get_config("registry")
        if config_registry is not Undefined:
            config_registry = cast(registry, config_registry)
            # If it was passed by kwargs, ensure it's also set in config
            # NOTE(review): this stores ``config_table``, not
            # ``config_registry``, under "registry". Do not change it blindly:
            # presumably subclasses inherit the model config, so storing the
            # real registry would make every subclass re-enter this branch and
            # get ``__abstract__ = True``. Confirm intended behavior first.
            set_config_value(model=new_cls, parameter="registry", value=config_table)
            setattr(new_cls, "_sa_registry", config_registry)  # noqa: B010
            setattr(new_cls, "metadata", config_registry.metadata)  # noqa: B010
            setattr(new_cls, "__abstract__", True)  # noqa: B010
        return new_cls

    # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
    def __init__(
        cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
    ) -> None:
        """Map table models with SQLAlchemy; initialise others as Pydantic models."""
        # Only one of the base classes (or the current one) should be a table model
        # this allows FastAPI cloning a SQLModel for the response_model without
        # trying to create a new SQLAlchemy, for a new table, with the same name, that
        # triggers an error
        base_is_table = any(is_table_model_class(base) for base in bases)
        if is_table_model_class(cls) and not base_is_table:
            for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
                if rel_info.sa_relationship:
                    # There's a SQLAlchemy relationship declared, that takes precedence
                    # over anything else, use that and continue with the next attribute
                    setattr(cls, rel_name, rel_info.sa_relationship)  # Fix #315
                    continue
                raw_ann = cls.__annotations__[rel_name]
                origin = get_origin(raw_ann)
                # Unwrap Mapped[X] to X so the target can be resolved below.
                if origin is Mapped:
                    ann = raw_ann.__args__[0]
                else:
                    ann = raw_ann
                    # Plain forward references, for models not yet defined, are not
                    # handled well by SQLAlchemy without Mapped, so, wrap the
                    # annotations in Mapped here
                    cls.__annotations__[rel_name] = Mapped[ann]  # type: ignore[valid-type]
                relationship_to = get_relationship_to(
                    name=rel_name, rel_info=rel_info, annotation=ann
                )
                # Translate the RelationshipInfo options into relationship() kwargs.
                rel_kwargs: Dict[str, Any] = {}
                if rel_info.back_populates:
                    rel_kwargs["back_populates"] = rel_info.back_populates
                if rel_info.cascade_delete:
                    rel_kwargs["cascade"] = "all, delete-orphan"
                if rel_info.passive_deletes:
                    rel_kwargs["passive_deletes"] = rel_info.passive_deletes
                if rel_info.link_model:
                    # Many-to-many: the link model's table is the secondary table.
                    ins = inspect(rel_info.link_model)
                    local_table = getattr(ins, "local_table")  # noqa: B009
                    if local_table is None:
                        raise RuntimeError(
                            "Couldn't find the secondary table for "
                            f"model {rel_info.link_model}"
                        )
                    rel_kwargs["secondary"] = local_table
                rel_args: List[Any] = []
                if rel_info.sa_relationship_args:
                    rel_args.extend(rel_info.sa_relationship_args)
                # User-supplied kwargs are applied last and override the above.
                if rel_info.sa_relationship_kwargs:
                    rel_kwargs.update(rel_info.sa_relationship_kwargs)
                rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
                setattr(cls, rel_name, rel_value)  # Fix #315
            # SQLAlchemy no longer uses dict_
            # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
            # Tag: 1.4.36
            DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
        else:
            ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
def get_sqlalchemy_type(field: Any) -> Any:
    """Choose the SQLAlchemy column type for a model field.

    An explicit ``sa_type`` on the field wins. Otherwise the Python type
    reported by ``get_sa_type_from_field`` is matched, in a fixed order,
    against the supported types, and the first match decides.

    Raises:
        ValueError: if no SQLAlchemy type matches the field's Python type.
    """
    field_info = field if IS_PYDANTIC_V2 else field.field_info
    explicit_type = getattr(field_info, "sa_type", Undefined)  # noqa: B009
    if explicit_type is not Undefined:
        return explicit_type

    py_type = get_sa_type_from_field(field)
    metadata = get_field_metadata(field)

    # Enums go first: an Enum may also subclass str, needed by Pydantic/FastAPI,
    # and must still map to an SQL Enum rather than a string column.
    if issubclass(py_type, Enum):
        return sa_Enum(py_type)

    string_like = (
        str,
        ipaddress.IPv4Address,
        ipaddress.IPv4Network,
        ipaddress.IPv6Address,
        ipaddress.IPv6Network,
        Path,
        EmailStr,
    )
    if issubclass(py_type, string_like):
        max_length = getattr(metadata, "max_length", None)
        return AutoString(length=max_length) if max_length else AutoString

    # Order matters: bool is tested before its base class int, and datetime
    # before its base class date.
    for candidate, column_type in (
        (float, Float),
        (bool, Boolean),
        (int, Integer),
        (datetime, DateTime),
        (date, Date),
        (timedelta, Interval),
        (time, Time),
        (bytes, LargeBinary),
    ):
        if issubclass(py_type, candidate):
            return column_type

    if issubclass(py_type, Decimal):
        return Numeric(
            precision=getattr(metadata, "max_digits", None),
            scale=getattr(metadata, "decimal_places", None),
        )
    if issubclass(py_type, uuid.UUID):
        return Uuid
    raise ValueError(f"{py_type} has no matching SQLAlchemy type")
def get_column_from_field(field: Any) -> Column:  # type: ignore
    """Build the SQLAlchemy ``Column`` for a model field.

    A ``Column`` given via ``sa_column`` is returned as-is. Otherwise the
    column type comes from ``get_sqlalchemy_type`` and the field's
    ``primary_key``, ``nullable``, ``index``, ``unique``, ``foreign_key`` /
    ``ondelete`` and default are turned into column arguments.
    ``sa_column_args`` and ``sa_column_kwargs`` are applied last, so they can
    extend or override the derived values.

    Raises:
        RuntimeError: if ``ondelete="SET NULL"`` is used on a foreign key
            column that is not nullable.
    """
    field_info = field if IS_PYDANTIC_V2 else field.field_info
    sa_column = getattr(field_info, "sa_column", Undefined)
    if isinstance(sa_column, Column):
        return sa_column
    sa_type = get_sqlalchemy_type(field)
    primary_key = getattr(field_info, "primary_key", Undefined)
    if primary_key is Undefined:
        primary_key = False
    index = getattr(field_info, "index", Undefined)
    if index is Undefined:
        index = False
    # Derived nullability: never for a primary key, otherwise whatever
    # ``is_field_noneable`` reports for the field.
    nullable = not primary_key and is_field_noneable(field)
    # Override derived nullability if the nullable property is set explicitly
    # on the field
    field_nullable = getattr(field_info, "nullable", Undefined)  # noqa: B009
    if field_nullable is not Undefined:
        assert not isinstance(field_nullable, UndefinedType)
        nullable = field_nullable
    args: List[Any] = []
    foreign_key = getattr(field_info, "foreign_key", Undefined)
    if foreign_key is Undefined:
        foreign_key = None
    unique = getattr(field_info, "unique", Undefined)
    if unique is Undefined:
        unique = False
    if foreign_key:
        # Read ondelete once, with the same defensive getattr as the other
        # options, and use that value both for validation and the ForeignKey.
        ondelete = getattr(field_info, "ondelete", Undefined)
        if ondelete is Undefined:
            ondelete = None
        if ondelete == "SET NULL" and not nullable:
            raise RuntimeError('ondelete="SET NULL" requires nullable=True')
        assert isinstance(foreign_key, str)
        assert isinstance(ondelete, (str, type(None)))  # for typing
        args.append(ForeignKey(foreign_key, ondelete=ondelete))
    kwargs = {
        "primary_key": primary_key,
        "nullable": nullable,
        "index": index,
        "unique": unique,
    }
    # A default_factory takes precedence over a plain default.
    sa_default = Undefined
    if field_info.default_factory:
        sa_default = field_info.default_factory
    elif field_info.default is not Undefined:
        sa_default = field_info.default
    if sa_default is not Undefined:
        kwargs["default"] = sa_default
    sa_column_args = getattr(field_info, "sa_column_args", Undefined)
    if sa_column_args is not Undefined:
        args.extend(list(cast(Sequence[Any], sa_column_args)))
    sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
    if sa_column_kwargs is not Undefined:
        kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
    return Column(sa_type, *args, **kwargs)  # type: ignore
# Type variable bound to SQLModel so methods/classmethods can be typed as
# returning the concrete subclass they were called on.
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")

# Registry given to the SQLModel base class as its ``registry=`` keyword;
# the metaclass exposes its ``metadata`` on that class.
default_registry = registry()

# Weak-valued mapping; NOTE(review): not referenced in the visible part of
# this module — confirm whether anything still uses it.
class_registry = weakref.WeakValueDictionary()  # type: ignore
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
    # Base class for all models: a Pydantic model that, when declared with
    # ``table=True``, is also mapped as a SQLAlchemy table. (Documented with
    # comments rather than a docstring so ``SQLModel.__doc__`` is unchanged.)

    # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
    __slots__ = ("__weakref__",)
    __tablename__: ClassVar[Union[str, Callable[..., str]]]
    __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
    __name__: ClassVar[str]
    metadata: ClassVar[MetaData]
    __allow_unmapped__ = True  # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six

    if IS_PYDANTIC_V2:
        model_config = SQLModelConfig(from_attributes=True)
    else:

        class Config:
            orm_mode = True

    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
        """Allocate an instance with Pydantic's private attributes initialised."""
        new_object = super().__new__(cls)
        # SQLAlchemy doesn't call __init__ on the base class when querying from DB
        # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
        # Set __fields_set__ here, that would have been set when calling __init__
        # in the Pydantic model so that when SQLAlchemy sets attributes that are
        # added (e.g. when querying from DB) to the __fields_set__, this already exists
        init_pydantic_private_attrs(new_object)
        return new_object

    def __init__(__pydantic_self__, **data: Any) -> None:
        """Validate ``data`` and populate the instance, unless ``finish_init`` is off."""
        # Uses something other than `self` the first arg to allow "self" as a
        # settable attribute
        # SQLAlchemy does very dark black magic and modifies the __init__ method in
        # sqlalchemy.orm.instrumentation._generate_init()
        # so, to make SQLAlchemy work, it's needed to explicitly call __init__ to
        # trigger all the SQLAlchemy logic, it doesn't work using cls.__new__, setting
        # attributes obj.__dict__, etc. The __init__ method has to be called. But
        # there are cases where calling all the default logic is not ideal, e.g.
        # when calling Model.model_validate(), as the validation is done outside
        # of instance creation.
        # At the same time, __init__ is what users would normally call, by creating
        # a new instance, which should have validation and all the default logic.
        # So, to be able to set up the internal SQLAlchemy logic alone without
        # executing the rest, and support things like Model.model_validate(), we
        # use a contextvar to know if we should execute everything.
        if finish_init.get():
            sqlmodel_init(self=__pydantic_self__, data=data)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute on both the SQLAlchemy and Pydantic sides."""
        if name in {"_sa_instance_state"}:
            # SQLAlchemy's internal state bypasses Pydantic entirely.
            self.__dict__[name] = value
            return
        # Set in SQLAlchemy, before Pydantic to trigger events and updates
        if is_table_model_class(self.__class__) and is_instrumented(self, name):  # type: ignore[no-untyped-call]
            set_attribute(self, name, value)
        # Set in Pydantic model to trigger possible validation changes, only for
        # non relationship values
        if name not in self.__sqlmodel_relationships__:
            super().__setattr__(name, value)

    def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
        # Don't show SQLAlchemy private attributes
        return [
            (k, v)
            for k, v in super().__repr_args__()
            if not (isinstance(k, str) and k.startswith("_sa_"))
        ]

    @declared_attr  # type: ignore
    def __tablename__(cls) -> str:
        # Default table name: the lower-cased class name.
        return cls.__name__.lower()

    @classmethod
    def model_validate(
        cls: Type[_TSQLModel],
        obj: Any,
        *,
        strict: Union[bool, None] = None,
        from_attributes: Union[bool, None] = None,
        context: Union[Dict[str, Any], None] = None,
        update: Union[Dict[str, Any], None] = None,
    ) -> _TSQLModel:
        """Validate ``obj`` into a model instance; ``update`` supplies extra values."""
        return sqlmodel_validate(
            cls=cls,
            obj=obj,
            strict=strict,
            from_attributes=from_attributes,
            context=context,
            update=update,
        )

    def model_dump(
        self,
        *,
        mode: Union[Literal["json", "python"], str] = "python",
        include: Union[IncEx, None] = None,
        exclude: Union[IncEx, None] = None,
        context: Union[Dict[str, Any], None] = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        round_trip: bool = False,
        warnings: Union[bool, Literal["none", "warn", "error"]] = True,
        serialize_as_any: bool = False,
    ) -> Dict[str, Any]:
        """Serialize the model to a dict on both Pydantic v1 and v2.

        ``context`` and ``serialize_as_any`` are only forwarded on Pydantic
        2.7+, which introduced them. On Pydantic v1, ``dict()`` is used and
        the v2-only options are ignored.
        """
        import re

        # Compare the version numerically: as plain strings "2.10.0" sorts
        # before "2.7.0", which dropped these options on newer releases.
        version_match = re.match(r"(\d+)\.(\d+)", str(PYDANTIC_VERSION))
        pydantic_minor = (
            (int(version_match.group(1)), int(version_match.group(2)))
            if version_match
            else (0, 0)
        )
        if pydantic_minor >= (2, 7):
            extra_kwargs: Dict[str, Any] = {
                "context": context,
                "serialize_as_any": serialize_as_any,
            }
        else:
            extra_kwargs = {}
        if IS_PYDANTIC_V2:
            return super().model_dump(
                mode=mode,
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
                round_trip=round_trip,
                warnings=warnings,
                **extra_kwargs,
            )
        return super().dict(
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )

    @deprecated(
        "\n"
        "🚨 `obj.dict()` was deprecated in SQLModel 0.0.14, you should\n"
        "instead use `obj.model_dump()`.\n"
    )
    def dict(
        self,
        *,
        include: Union[IncEx, None] = None,
        exclude: Union[IncEx, None] = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Dict[str, Any]:
        """Deprecated alias of :meth:`model_dump`."""
        return self.model_dump(
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )

    @classmethod
    @deprecated(
        "\n"
        "🚨 `obj.from_orm(data)` was deprecated in SQLModel 0.0.14, you should\n"
        "instead use `obj.model_validate(data)`.\n"
    )
    def from_orm(
        cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
    ) -> _TSQLModel:
        """Deprecated alias of :meth:`model_validate`."""
        return cls.model_validate(obj, update=update)

    @classmethod
    @deprecated(
        "\n"
        "🚨 `obj.parse_obj(data)` was deprecated in SQLModel 0.0.14, you should\n"
        "instead use `obj.model_validate(data)`.\n"
    )
    def parse_obj(
        cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
    ) -> _TSQLModel:
        """Deprecated alias of :meth:`model_validate`."""
        if not IS_PYDANTIC_V2:
            obj = cls._enforce_dict_if_root(obj)  # type: ignore[attr-defined] # noqa
        return cls.model_validate(obj, update=update)

    # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
    @deprecated(
        "\n"
        "🚨 You should not access `obj._calculate_keys()` directly.\n"
        "It is only useful for Pydantic v1.X, you should probably upgrade to\n"
        "Pydantic v2.X.\n",
        category=None,
    )
    def _calculate_keys(
        self,
        include: Optional[Mapping[Union[int, str], Any]],
        exclude: Optional[Mapping[Union[int, str], Any]],
        exclude_unset: bool,
        update: Optional[Dict[str, Any]] = None,
    ) -> Optional[AbstractSet[str]]:
        # Delegates to the module-level ``_calculate_keys`` helper.
        return _calculate_keys(
            self,
            include=include,
            exclude=exclude,
            exclude_unset=exclude_unset,
            update=update,
        )

    def sqlmodel_update(
        self: _TSQLModel,
        obj: Union[Dict[str, Any], BaseModel],
        *,
        update: Union[Dict[str, Any], None] = None,
    ) -> _TSQLModel:
        """Copy values from ``obj`` (and ``update``) onto this instance.

        For a dict, keys that are fields of this model are set, with
        ``update`` taking precedence. For a Pydantic model, every field of
        ``obj`` is set (overridden by ``update`` when present), then any
        remaining ``update`` keys that are fields of this model are set.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: if ``obj`` is neither a dict nor a Pydantic model.
        """
        use_update = (update or {}).copy()
        if isinstance(obj, dict):
            own_fields = get_model_fields(self)
            for key, value in {**obj, **use_update}.items():
                if key in own_fields:
                    setattr(self, key, value)
        elif isinstance(obj, BaseModel):
            for key in get_model_fields(obj):
                if key in use_update:
                    value = use_update.pop(key)
                else:
                    value = getattr(obj, key)
                setattr(self, key, value)
            # Apply leftover update entries targeting this model's fields.
            # Iterate without mutating: popping from ``use_update`` inside this
            # loop raised "dictionary changed size during iteration".
            own_fields = get_model_fields(self)
            for remaining_key, value in use_update.items():
                if remaining_key in own_fields:
                    setattr(self, remaining_key, value)
        else:
            raise ValueError(
                "Can't use sqlmodel_update() with something that "
                f"is not a dict or SQLModel or Pydantic model: {obj}"
            )
        return self