1011 lines
36 KiB
Python
1011 lines
36 KiB
Python
import ipaddress
|
|
import uuid
|
|
import weakref
|
|
from datetime import date, datetime, time, timedelta
|
|
from decimal import Decimal
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
AbstractSet,
|
|
Any,
|
|
Callable,
|
|
ClassVar,
|
|
Dict,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
overload,
|
|
)
|
|
|
|
from pydantic import BaseModel, EmailStr
|
|
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
|
from sqlalchemy import (
|
|
Boolean,
|
|
Column,
|
|
Date,
|
|
DateTime,
|
|
Float,
|
|
ForeignKey,
|
|
Integer,
|
|
Interval,
|
|
Numeric,
|
|
inspect,
|
|
)
|
|
from sqlalchemy import Enum as sa_Enum
|
|
from sqlalchemy.orm import (
|
|
Mapped,
|
|
RelationshipProperty,
|
|
declared_attr,
|
|
registry,
|
|
relationship,
|
|
)
|
|
from sqlalchemy.orm.attributes import set_attribute
|
|
from sqlalchemy.orm.decl_api import DeclarativeMeta
|
|
from sqlalchemy.orm.instrumentation import is_instrumented
|
|
from sqlalchemy.sql.schema import MetaData
|
|
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
|
|
from typing_extensions import Literal, TypeAlias, deprecated, get_origin
|
|
|
|
from ._compat import ( # type: ignore[attr-defined]
|
|
IS_PYDANTIC_V2,
|
|
PYDANTIC_VERSION,
|
|
BaseConfig,
|
|
ModelField,
|
|
ModelMetaclass,
|
|
Representation,
|
|
SQLModelConfig,
|
|
Undefined,
|
|
UndefinedType,
|
|
_calculate_keys,
|
|
finish_init,
|
|
get_annotations,
|
|
get_config_value,
|
|
get_field_metadata,
|
|
get_model_fields,
|
|
get_relationship_to,
|
|
get_sa_type_from_field,
|
|
init_pydantic_private_attrs,
|
|
is_field_noneable,
|
|
is_table_model_class,
|
|
post_init_field_info,
|
|
set_config_value,
|
|
sqlmodel_init,
|
|
sqlmodel_validate,
|
|
)
|
|
from .sql.sqltypes import AutoString
|
|
|
|
if TYPE_CHECKING:
|
|
from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
|
|
from pydantic._internal._repr import Representation as Representation
|
|
from pydantic_core import PydanticUndefined as Undefined
|
|
from pydantic_core import PydanticUndefinedType as UndefinedType
|
|
|
|
_T = TypeVar("_T")
|
|
NoArgAnyCallable = Callable[[], Any]
|
|
IncEx: TypeAlias = Union[
|
|
Set[int],
|
|
Set[str],
|
|
Mapping[int, Union["IncEx", Literal[True]]],
|
|
Mapping[str, Union["IncEx", Literal[True]]],
|
|
]
|
|
OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"]
|
|
|
|
|
|
def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]:
    """Return a decorator that leaves its target untouched.

    The signature follows the ``dataclass_transform`` proposal (PEP 681),
    presumably so static type checkers treat decorated classes as
    dataclass-like. At runtime every argument is ignored and the returned
    decorator simply hands back the object it is applied to.
    """

    def _passthrough(target: _T) -> _T:
        return target

    return _passthrough
|
|
|
|
|
|
class FieldInfo(PydanticFieldInfo):
    """Pydantic ``FieldInfo`` extended with SQLModel's database column options.

    Instances are normally created by ``Field()``. The SQLModel-specific
    keyword arguments are removed from ``kwargs`` and stored as attributes;
    every remaining keyword argument is forwarded to the Pydantic
    ``FieldInfo`` constructor.
    """

    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
        """Store the column options and initialise the Pydantic part.

        Args:
            default: Default value of the field, forwarded to Pydantic.
            **kwargs: Pydantic field options plus the SQLModel options
                ``primary_key``, ``nullable``, ``foreign_key``, ``ondelete``,
                ``unique``, ``index``, ``sa_type``, ``sa_column``,
                ``sa_column_args`` and ``sa_column_kwargs``. An option that is
                not supplied is stored as ``Undefined``, which is also what
                ``Field()`` passes for omitted arguments.

        Raises:
            RuntimeError: If ``sa_column`` is combined with any other column
                option, or if ``ondelete`` is given without ``foreign_key``.
        """
        # Every option defaults to Undefined so that the sa_column conflict
        # check below only rejects options that were actually supplied.
        # (Defaulting primary_key/unique to False made that check fire for
        # every direct ``FieldInfo(sa_column=...)`` call.)
        primary_key = kwargs.pop("primary_key", Undefined)
        nullable = kwargs.pop("nullable", Undefined)
        foreign_key = kwargs.pop("foreign_key", Undefined)
        ondelete = kwargs.pop("ondelete", Undefined)
        unique = kwargs.pop("unique", Undefined)
        index = kwargs.pop("index", Undefined)
        sa_type = kwargs.pop("sa_type", Undefined)
        sa_column = kwargs.pop("sa_column", Undefined)
        sa_column_args = kwargs.pop("sa_column_args", Undefined)
        sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
        if sa_column is not Undefined:
            # A ready-made sa_column fully defines the column, so any other
            # column option would be ignored; reject it instead. The tuple
            # order decides which option is reported when several are given.
            conflicting_options = (
                ("sa_column_args", sa_column_args),
                ("sa_column_kwargs", sa_column_kwargs),
                ("primary_key", primary_key),
                ("nullable", nullable),
                ("foreign_key", foreign_key),
                ("ondelete", ondelete),
                ("unique", unique),
                ("index", index),
                ("sa_type", sa_type),
            )
            for option_name, option_value in conflicting_options:
                if option_value is not Undefined:
                    raise RuntimeError(
                        f"Passing {option_name} is not supported when "
                        "also passing a sa_column"
                    )
        if ondelete is not Undefined and foreign_key is Undefined:
            raise RuntimeError("ondelete can only be used with foreign_key")
        super().__init__(default=default, **kwargs)
        self.primary_key = primary_key
        self.nullable = nullable
        self.foreign_key = foreign_key
        self.ondelete = ondelete
        self.unique = unique
        self.index = index
        self.sa_type = sa_type
        self.sa_column = sa_column
        self.sa_column_args = sa_column_args
        self.sa_column_kwargs = sa_column_kwargs
|
|
|
|
|
|
class RelationshipInfo(Representation):
    """Options given to ``Relationship()`` for one model attribute.

    ``SQLModelMetaclass`` collects these objects from the class body and, for
    table models, turns each one into a SQLAlchemy ``relationship()``.
    """

    def __init__(
        self,
        *,
        back_populates: Optional[str] = None,
        cascade_delete: Optional[bool] = False,
        passive_deletes: Optional[Union[bool, Literal["all"]]] = False,
        link_model: Optional[Any] = None,
        sa_relationship: Optional[RelationshipProperty] = None,  # type: ignore
        sa_relationship_args: Optional[Sequence[Any]] = None,
        sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Validate and store the relationship options.

        Raises:
            RuntimeError: If ``sa_relationship`` is combined with
                ``sa_relationship_args`` or ``sa_relationship_kwargs``.
        """
        if sa_relationship is not None:
            # A complete SQLAlchemy relationship leaves nothing for the extra
            # args/kwargs to configure, so reject them (args checked first).
            extra_options = {
                "sa_relationship_args": sa_relationship_args,
                "sa_relationship_kwargs": sa_relationship_kwargs,
            }
            for option_name, option_value in extra_options.items():
                if option_value is not None:
                    raise RuntimeError(
                        f"Passing {option_name} is not supported when "
                        "also passing a sa_relationship"
                    )
        # Assignment order is kept as-is: it fixes the order of the instance
        # ``__dict__``, which the inherited repr support may iterate.
        self.back_populates = back_populates
        self.cascade_delete = cascade_delete
        self.passive_deletes = passive_deletes
        self.link_model = link_model
        self.sa_relationship = sa_relationship
        self.sa_relationship_args = sa_relationship_args
        self.sa_relationship_kwargs = sa_relationship_kwargs
|
|
|
|
|
|
# Overload 1: SQLModel builds the column itself. Accepts primary_key, a
# foreign_key of any type, unique, nullable, index, sa_type, sa_column_args
# and sa_column_kwargs (no ondelete).
@overload
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    max_digits: Optional[int] = None,
    decimal_places: Optional[int] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: Optional[bool] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    discriminator: Optional[str] = None,
    repr: bool = True,
    primary_key: Union[bool, UndefinedType] = Undefined,
    foreign_key: Any = Undefined,
    unique: Union[bool, UndefinedType] = Undefined,
    nullable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    sa_type: Union[Type[Any], UndefinedType] = Undefined,
    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any: ...


# Overload 2: like overload 1, but foreign_key is a required string and an
# ondelete referential action may accompany it.
@overload
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    max_digits: Optional[int] = None,
    decimal_places: Optional[int] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: Optional[bool] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    discriminator: Optional[str] = None,
    repr: bool = True,
    primary_key: Union[bool, UndefinedType] = Undefined,
    foreign_key: str,
    ondelete: Union[OnDeleteType, UndefinedType] = Undefined,
    unique: Union[bool, UndefinedType] = Undefined,
    nullable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    sa_type: Union[Type[Any], UndefinedType] = Undefined,
    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any: ...


# Overload 3: the caller supplies a complete sa_column, so none of the other
# column options (primary_key, foreign_key, ondelete, unique, nullable, index,
# sa_type, sa_column_args, sa_column_kwargs) are accepted.
@overload
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    max_digits: Optional[int] = None,
    decimal_places: Optional[int] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: Optional[bool] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    discriminator: Optional[str] = None,
    repr: bool = True,
    sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any: ...


def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    max_digits: Optional[int] = None,
    decimal_places: Optional[int] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    unique_items: Optional[bool] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    discriminator: Optional[str] = None,
    repr: bool = True,
    primary_key: Union[bool, UndefinedType] = Undefined,
    foreign_key: Any = Undefined,
    ondelete: Union[OnDeleteType, UndefinedType] = Undefined,
    unique: Union[bool, UndefinedType] = Undefined,
    nullable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    sa_type: Union[Type[Any], UndefinedType] = Undefined,
    sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
    """Declare a model attribute with validation and database column options.

    The Pydantic-related arguments are forwarded to ``FieldInfo`` unchanged.
    The column arguments (``primary_key`` through ``sa_column_kwargs``) are
    stored on the returned ``FieldInfo``; for table models the metaclass turns
    them into a SQLAlchemy ``Column`` via ``get_column_from_field``.
    ``FieldInfo`` raises ``RuntimeError`` when ``sa_column`` is combined with
    another column argument, or when ``ondelete`` is given without
    ``foreign_key``. Entries of ``schema_extra`` are passed to ``FieldInfo``
    as additional keyword arguments.

    Returns:
        The ``FieldInfo`` instance, after ``post_init_field_info`` has been
        applied to it.
    """
    column_options: Dict[str, Any] = {
        "primary_key": primary_key,
        "foreign_key": foreign_key,
        "ondelete": ondelete,
        "unique": unique,
        "nullable": nullable,
        "index": index,
        "sa_type": sa_type,
        "sa_column": sa_column,
        "sa_column_args": sa_column_args,
        "sa_column_kwargs": sa_column_kwargs,
    }
    extra_schema = schema_extra or {}
    # A key present both in schema_extra and among the explicit options makes
    # this call raise TypeError ("got multiple values for keyword argument").
    info = FieldInfo(
        default,
        default_factory=default_factory,
        alias=alias,
        title=title,
        description=description,
        exclude=exclude,
        include=include,
        const=const,
        gt=gt,
        ge=ge,
        lt=lt,
        le=le,
        multiple_of=multiple_of,
        max_digits=max_digits,
        decimal_places=decimal_places,
        min_items=min_items,
        max_items=max_items,
        unique_items=unique_items,
        min_length=min_length,
        max_length=max_length,
        allow_mutation=allow_mutation,
        regex=regex,
        discriminator=discriminator,
        repr=repr,
        **column_options,
        **extra_schema,
    )
    post_init_field_info(info)
    return info
|
|
|
|
|
|
# Overload 1: SQLModel builds the relationship, optionally with extra
# positional/keyword arguments for SQLAlchemy's ``relationship()``.
@overload
def Relationship(
    *,
    back_populates: Optional[str] = None,
    cascade_delete: Optional[bool] = False,
    passive_deletes: Optional[Union[bool, Literal["all"]]] = False,
    link_model: Optional[Any] = None,
    sa_relationship_args: Optional[Sequence[Any]] = None,
    sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any: ...


# Overload 2: the caller passes a ready-made SQLAlchemy relationship.
@overload
def Relationship(
    *,
    back_populates: Optional[str] = None,
    cascade_delete: Optional[bool] = False,
    passive_deletes: Optional[Union[bool, Literal["all"]]] = False,
    link_model: Optional[Any] = None,
    sa_relationship: Optional[RelationshipProperty[Any]] = None,
) -> Any: ...


def Relationship(
    *,
    back_populates: Optional[str] = None,
    cascade_delete: Optional[bool] = False,
    passive_deletes: Optional[Union[bool, Literal["all"]]] = False,
    link_model: Optional[Any] = None,
    sa_relationship: Optional[RelationshipProperty[Any]] = None,
    sa_relationship_args: Optional[Sequence[Any]] = None,
    sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
    """Declare a relationship attribute on a model.

    All options are packed into a ``RelationshipInfo``, which raises
    ``RuntimeError`` if ``sa_relationship`` is combined with
    ``sa_relationship_args`` or ``sa_relationship_kwargs``. For table models
    the metaclass later converts it into a SQLAlchemy relationship.

    Returns:
        The ``RelationshipInfo`` holding the given options.
    """
    return RelationshipInfo(
        back_populates=back_populates,
        cascade_delete=cascade_delete,
        passive_deletes=passive_deletes,
        link_model=link_model,
        sa_relationship=sa_relationship,
        sa_relationship_args=sa_relationship_args,
        sa_relationship_kwargs=sa_relationship_kwargs,
    )
|
|
|
|
|
|
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
    """Metaclass for SQLModel models, combining Pydantic's and SQLAlchemy's.

    Every class is created by Pydantic's ``ModelMetaclass``. Classes
    configured with ``table=True`` additionally get a SQLAlchemy ``Column``
    for each field and a SQLAlchemy relationship for each ``Relationship()``
    attribute, and are initialised through ``DeclarativeMeta``.
    """

    __sqlmodel_relationships__: Dict[str, RelationshipInfo]
    model_config: SQLModelConfig
    model_fields: Dict[str, FieldInfo]
    __config__: Type[SQLModelConfig]
    __fields__: Dict[str, ModelField]  # type: ignore[assignment]

    # Replicate SQLAlchemy: table classes route class attribute access
    # through DeclarativeMeta, other classes use the regular path.
    def __setattr__(cls, name: str, value: Any) -> None:
        if not is_table_model_class(cls):
            super().__setattr__(name, value)
            return
        DeclarativeMeta.__setattr__(cls, name, value)

    def __delattr__(cls, name: str) -> None:
        if not is_table_model_class(cls):
            super().__delattr__(name)
            return
        DeclarativeMeta.__delattr__(cls, name)

    # From Pydantic
    def __new__(
        cls,
        name: str,
        bases: Tuple[Type[Any], ...],
        class_dict: Dict[str, Any],
        **kwargs: Any,
    ) -> Any:
        """Create a model class, hiding ``Relationship()`` attributes from Pydantic.

        Relationship attributes and their annotations are taken out of the
        namespace given to Pydantic and kept in ``__sqlmodel_relationships__``.
        For ``table=True`` classes each field is replaced by its SQLAlchemy
        column; a configured ``registry`` is attached as ``_sa_registry`` /
        ``metadata`` and the class is marked ``__abstract__``.
        """
        original_annotations = get_annotations(class_dict)
        # Split the class body: Relationship() values go to SQLModel, all
        # other entries are handed to Pydantic unchanged (order preserved).
        relationships: Dict[str, RelationshipInfo] = {
            attr: value
            for attr, value in class_dict.items()
            if isinstance(value, RelationshipInfo)
        }
        dict_for_pydantic = {
            attr: value
            for attr, value in class_dict.items()
            if not isinstance(value, RelationshipInfo)
        }
        relationship_annotations = {
            attr: annotation
            for attr, annotation in original_annotations.items()
            if attr in relationships
        }
        pydantic_annotations = {
            attr: annotation
            for attr, annotation in original_annotations.items()
            if attr not in relationships
        }
        dict_used = {
            **dict_for_pydantic,
            "__weakref__": None,
            "__sqlmodel_relationships__": relationships,
            "__annotations__": pydantic_annotations,
        }
        # Duplicate logic from Pydantic to filter config kwargs because if they
        # are passed directly including the registry Pydantic will pass them
        # over to the superclass causing an error
        allowed_config_kwargs: Set[str] = {
            key
            for key in dir(BaseConfig)
            # skip dunder methods and attributes
            if not (key.startswith("__") and key.endswith("__"))
        }
        config_kwargs = {
            key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
        }
        new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
        new_cls.__annotations__ = {
            **relationship_annotations,
            **pydantic_annotations,
            **new_cls.__annotations__,
        }

        def get_config(parameter: str) -> Any:
            # A value in the model config wins over a class keyword argument;
            # Undefined when neither is set.
            from_config = get_config_value(
                model=new_cls, parameter=parameter, default=Undefined
            )
            if from_config is not Undefined:
                return from_config
            return kwargs.get(parameter, Undefined)

        config_table = get_config("table")
        if config_table is True:
            # If it was passed by kwargs, ensure it's also set in config
            set_config_value(model=new_cls, parameter="table", value=config_table)
            for field_name, model_field in get_model_fields(new_cls).items():
                setattr(new_cls, field_name, get_column_from_field(model_field))
            # Set a config flag to tell FastAPI that this should be read with a
            # field in orm_mode instead of preemptively converting it to a
            # dict. This could be done by reading new_cls.model_config['table']
            # in FastAPI, but that's very specific about SQLModel, so let's have
            # another config that other future tools based on Pydantic can use.
            set_config_value(
                model=new_cls, parameter="read_from_attributes", value=True
            )
            # For compatibility with older versions
            # TODO: remove this in the future
            set_config_value(model=new_cls, parameter="read_with_orm_mode", value=True)

        config_registry = get_config("registry")
        if config_registry is not Undefined:
            config_registry = cast(registry, config_registry)
            # NOTE(review): this stores ``config_table`` (not the registry)
            # under "registry". It looks like a typo, but if the registry were
            # stored here and inherited by subclasses' config, every subclass
            # would presumably re-enter this branch and be marked
            # ``__abstract__`` (i.e. never mapped). Verify before changing.
            set_config_value(model=new_cls, parameter="registry", value=config_table)
            setattr(new_cls, "_sa_registry", config_registry)  # noqa: B010
            setattr(new_cls, "metadata", config_registry.metadata)  # noqa: B010
            setattr(new_cls, "__abstract__", True)  # noqa: B010
        return new_cls

    # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
    def __init__(
        cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
    ) -> None:
        """Finish class creation; table models also get their relationships.

        Only one of the base classes (or the current one) should be a table
        model. This allows FastAPI cloning a SQLModel for the response_model
        without trying to create a new SQLAlchemy table with the same name,
        which triggers an error.
        """
        base_is_table = any(is_table_model_class(base) for base in bases)
        if not is_table_model_class(cls) or base_is_table:
            ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
            return

        for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
            if rel_info.sa_relationship:
                # A declared SQLAlchemy relationship takes precedence over
                # everything else; use it as-is.
                setattr(cls, rel_name, rel_info.sa_relationship)  # Fix #315
                continue
            raw_ann = cls.__annotations__[rel_name]
            ann = raw_ann.__args__[0] if get_origin(raw_ann) is Mapped else raw_ann
            # Plain forward references, for models not yet defined, are not
            # handled well by SQLAlchemy without Mapped, so, wrap the
            # annotations in Mapped here
            cls.__annotations__[rel_name] = Mapped[ann]  # type: ignore[valid-type]
            relationship_to = get_relationship_to(
                name=rel_name, rel_info=rel_info, annotation=ann
            )
            rel_kwargs: Dict[str, Any] = {}
            if rel_info.back_populates:
                rel_kwargs["back_populates"] = rel_info.back_populates
            if rel_info.cascade_delete:
                rel_kwargs["cascade"] = "all, delete-orphan"
            if rel_info.passive_deletes:
                rel_kwargs["passive_deletes"] = rel_info.passive_deletes
            if rel_info.link_model:
                ins = inspect(rel_info.link_model)
                local_table = getattr(ins, "local_table")  # noqa: B009
                if local_table is None:
                    raise RuntimeError(
                        "Couldn't find the secondary table for "
                        f"model {rel_info.link_model}"
                    )
                rel_kwargs["secondary"] = local_table
            rel_args: List[Any] = []
            if rel_info.sa_relationship_args:
                rel_args.extend(rel_info.sa_relationship_args)
            if rel_info.sa_relationship_kwargs:
                rel_kwargs.update(rel_info.sa_relationship_kwargs)
            built_relationship = relationship(relationship_to, *rel_args, **rel_kwargs)
            setattr(cls, rel_name, built_relationship)  # Fix #315
        # SQLAlchemy no longer uses dict_
        # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
        # Tag: 1.4.36
        DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
|
|
|
|
|
def get_sqlalchemy_type(field: Any) -> Any:
    """Choose the SQLAlchemy column type for a model field.

    An explicit ``sa_type`` on the field wins. Otherwise the Python type given
    by ``get_sa_type_from_field`` is matched, in order, against: enums;
    string-like types (str, IPv4/IPv6 addresses and networks, ``Path``,
    ``EmailStr``); float; bool; int; datetime; date; timedelta; time; bytes;
    ``Decimal``; ``uuid.UUID``. Order matters where types overlap: an enum may
    also be a ``str``, ``bool`` subclasses ``int`` and ``datetime`` subclasses
    ``date``.

    Raises:
        ValueError: If no SQLAlchemy type matches the Python type.
    """
    info = field if IS_PYDANTIC_V2 else field.field_info
    explicit_sa_type = getattr(info, "sa_type", Undefined)  # noqa: B009
    if explicit_sa_type is not Undefined:
        return explicit_sa_type

    python_type = get_sa_type_from_field(field)
    metadata = get_field_metadata(field)

    # Enums first, as an enum can also be a str (needed by Pydantic/FastAPI)
    if issubclass(python_type, Enum):
        return sa_Enum(python_type)

    string_like_types = (
        str,
        ipaddress.IPv4Address,
        ipaddress.IPv4Network,
        ipaddress.IPv6Address,
        ipaddress.IPv6Network,
        Path,
        EmailStr,
    )
    if issubclass(python_type, string_like_types):
        max_length = getattr(metadata, "max_length", None)
        return AutoString(length=max_length) if max_length else AutoString

    # Checked top to bottom: bool before int, datetime before date.
    plain_mappings = (
        (float, Float),
        (bool, Boolean),
        (int, Integer),
        (datetime, DateTime),
        (date, Date),
        (timedelta, Interval),
        (time, Time),
        (bytes, LargeBinary),
    )
    for candidate, column_type in plain_mappings:
        if issubclass(python_type, candidate):
            return column_type

    if issubclass(python_type, Decimal):
        return Numeric(
            precision=getattr(metadata, "max_digits", None),
            scale=getattr(metadata, "decimal_places", None),
        )
    if issubclass(python_type, uuid.UUID):
        return Uuid
    raise ValueError(f"{python_type} has no matching SQLAlchemy type")
|
|
|
|
|
|
def get_column_from_field(field: Any) -> Column:  # type: ignore
    """Build the SQLAlchemy ``Column`` for one model field.

    Args:
        field: On Pydantic v2 the field's ``FieldInfo``; on Pydantic v1 an
            object whose ``field_info`` attribute holds it.

    Returns:
        The field's own ``sa_column`` if it is a ``Column``. Otherwise a new
        ``Column`` built from ``get_sqlalchemy_type(field)`` and the field's
        options: primary key, nullability, index, uniqueness, foreign key
        (with ``ondelete``), default / default factory, and any
        ``sa_column_args`` / ``sa_column_kwargs`` (the kwargs override the
        derived options).

    Raises:
        RuntimeError: If ``ondelete="SET NULL"`` is used on a column that is
            not nullable.
    """
    field_info = field if IS_PYDANTIC_V2 else field.field_info
    sa_column = getattr(field_info, "sa_column", Undefined)
    if isinstance(sa_column, Column):
        return sa_column
    sa_type = get_sqlalchemy_type(field)
    primary_key = getattr(field_info, "primary_key", Undefined)
    if primary_key is Undefined:
        primary_key = False
    index = getattr(field_info, "index", Undefined)
    if index is Undefined:
        index = False
    # Primary keys are never nullable; otherwise follow the field's type.
    nullable = not primary_key and is_field_noneable(field)
    # Override derived nullability if the nullable property is set explicitly
    # on the field
    field_nullable = getattr(field_info, "nullable", Undefined)  # noqa: B009
    if field_nullable is not Undefined:
        assert not isinstance(field_nullable, UndefinedType)
        nullable = field_nullable
    args = []
    foreign_key = getattr(field_info, "foreign_key", Undefined)
    if foreign_key is Undefined:
        foreign_key = None
    unique = getattr(field_info, "unique", Undefined)
    if unique is Undefined:
        unique = False
    if foreign_key:
        # Read ondelete the same guarded way as every other option, before it
        # is first used (it used to be accessed directly as an attribute).
        ondelete = getattr(field_info, "ondelete", Undefined)
        if ondelete is Undefined:
            ondelete = None
        if ondelete == "SET NULL" and not nullable:
            raise RuntimeError('ondelete="SET NULL" requires nullable=True')
        assert isinstance(foreign_key, str)
        assert isinstance(ondelete, (str, type(None)))  # for typing
        args.append(ForeignKey(foreign_key, ondelete=ondelete))
    kwargs = {
        "primary_key": primary_key,
        "nullable": nullable,
        "index": index,
        "unique": unique,
    }
    # A default factory takes precedence over a plain default value.
    sa_default = Undefined
    if field_info.default_factory:
        sa_default = field_info.default_factory
    elif field_info.default is not Undefined:
        sa_default = field_info.default
    if sa_default is not Undefined:
        kwargs["default"] = sa_default
    sa_column_args = getattr(field_info, "sa_column_args", Undefined)
    if sa_column_args is not Undefined:
        args.extend(list(cast(Sequence[Any], sa_column_args)))
    sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
    if sa_column_kwargs is not Undefined:
        kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
    return Column(sa_type, *args, **kwargs)  # type: ignore
|
|
|
|
|
|
# Weak-valued dictionary (entries vanish once their values are garbage
# collected). NOTE(review): its consumers are not visible in this module.
class_registry = weakref.WeakValueDictionary()  # type: ignore

# SQLAlchemy registry passed to the ``SQLModel`` base class below.
default_registry = registry()

# Type variable for ``SQLModel`` methods that return an instance of the
# concrete subclass they are called on.
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
|
|
|
|
|
|
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
    # Base class for all SQLModel models. A subclass declared with table=True
    # is also mapped by SQLAlchemy; any other subclass is a Pydantic model.

    # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
    __slots__ = ("__weakref__",)
    __tablename__: ClassVar[Union[str, Callable[..., str]]]
    __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
    __name__: ClassVar[str]
    metadata: ClassVar[MetaData]
    __allow_unmapped__ = True  # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six

    if IS_PYDANTIC_V2:
        model_config = SQLModelConfig(from_attributes=True)
    else:

        class Config:
            orm_mode = True

    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
        """Create the instance and set up its Pydantic private state.

        SQLAlchemy doesn't call __init__ on the base class when querying from
        DB (ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html), so
        the state __init__ would normally create (e.g. __fields_set__) is
        initialised here; SQLAlchemy can then record the attributes it sets.
        """
        new_object = super().__new__(cls)
        init_pydantic_private_attrs(new_object)
        return new_object

    def __init__(__pydantic_self__, **data: Any) -> None:
        # Uses something other than `self` the first arg to allow "self" as a
        # settable attribute

        # SQLAlchemy does very dark black magic and modifies the __init__ method in
        # sqlalchemy.orm.instrumentation._generate_init()
        # so, to make SQLAlchemy work, it's needed to explicitly call __init__ to
        # trigger all the SQLAlchemy logic, it doesn't work using cls.__new__, setting
        # attributes obj.__dict__, etc. The __init__ method has to be called. But
        # there are cases where calling all the default logic is not ideal, e.g.
        # when calling Model.model_validate(), as the validation is done outside
        # of instance creation.
        # At the same time, __init__ is what users would normally call, by creating
        # a new instance, which should have validation and all the default logic.
        # So, to be able to set up the internal SQLAlchemy logic alone without
        # executing the rest, and support things like Model.model_validate(), we
        # use a contextvar to know if we should execute everything.
        if finish_init.get():
            sqlmodel_init(self=__pydantic_self__, data=data)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute on both the SQLAlchemy and the Pydantic side."""
        if name == "_sa_instance_state":
            # SQLAlchemy's own state bypasses both SQLAlchemy and Pydantic.
            self.__dict__[name] = value
            return
        # Set in SQLAlchemy, before Pydantic to trigger events and updates
        if is_table_model_class(self.__class__) and is_instrumented(self, name):  # type: ignore[no-untyped-call]
            set_attribute(self, name, value)
        # Set in Pydantic model to trigger possible validation changes, only for
        # non relationship values
        if name not in self.__sqlmodel_relationships__:
            super().__setattr__(name, value)

    def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
        # Don't show SQLAlchemy private attributes
        return [
            (k, v)
            for k, v in super().__repr_args__()
            if not (isinstance(k, str) and k.startswith("_sa_"))
        ]

    @declared_attr  # type: ignore
    def __tablename__(cls) -> str:
        """Default table name: the class name in lower case."""
        return cls.__name__.lower()

    @classmethod
    def model_validate(
        cls: Type[_TSQLModel],
        obj: Any,
        *,
        strict: Union[bool, None] = None,
        from_attributes: Union[bool, None] = None,
        context: Union[Dict[str, Any], None] = None,
        update: Union[Dict[str, Any], None] = None,
    ) -> _TSQLModel:
        """Validate ``obj`` into a new instance; ``update`` entries are applied on top."""
        return sqlmodel_validate(
            cls=cls,
            obj=obj,
            strict=strict,
            from_attributes=from_attributes,
            context=context,
            update=update,
        )

    def model_dump(
        self,
        *,
        mode: Union[Literal["json", "python"], str] = "python",
        include: Union[IncEx, None] = None,
        exclude: Union[IncEx, None] = None,
        context: Union[Dict[str, Any], None] = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        round_trip: bool = False,
        warnings: Union[bool, Literal["none", "warn", "error"]] = True,
        serialize_as_any: bool = False,
    ) -> Dict[str, Any]:
        """Dump the model to a dict on both Pydantic v1 and v2.

        ``context`` and ``serialize_as_any`` are only forwarded on Pydantic
        2.7 and later; ``mode``, ``round_trip`` and ``warnings`` are ignored
        on Pydantic v1, which delegates to ``BaseModel.dict()``.
        """
        # PYDANTIC_VERSION is a string; comparing it lexicographically is
        # wrong ("2.10.0" < "2.7.0") and dropped context/serialize_as_any on
        # Pydantic >= 2.10. Compare the numeric (major, minor) prefix instead,
        # ignoring any non-digit suffix such as "0b1".
        version_numbers: List[int] = []
        for piece in str(PYDANTIC_VERSION).split(".")[:2]:
            digit_count = len(piece) - len(piece.lstrip("0123456789"))
            version_numbers.append(int(piece[:digit_count] or 0))
        if tuple(version_numbers) >= (2, 7):
            extra_kwargs: Dict[str, Any] = {
                "context": context,
                "serialize_as_any": serialize_as_any,
            }
        else:
            extra_kwargs = {}
        if IS_PYDANTIC_V2:
            return super().model_dump(
                mode=mode,
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
                round_trip=round_trip,
                warnings=warnings,
                **extra_kwargs,
            )
        else:
            return super().dict(
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )

    @deprecated(
        """
        🚨 `obj.dict()` was deprecated in SQLModel 0.0.14, you should
        instead use `obj.model_dump()`.
        """
    )
    def dict(
        self,
        *,
        include: Union[IncEx, None] = None,
        exclude: Union[IncEx, None] = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Dict[str, Any]:
        return self.model_dump(
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )

    @classmethod
    @deprecated(
        """
        🚨 `obj.from_orm(data)` was deprecated in SQLModel 0.0.14, you should
        instead use `obj.model_validate(data)`.
        """
    )
    def from_orm(
        cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
    ) -> _TSQLModel:
        return cls.model_validate(obj, update=update)

    @classmethod
    @deprecated(
        """
        🚨 `obj.parse_obj(data)` was deprecated in SQLModel 0.0.14, you should
        instead use `obj.model_validate(data)`.
        """
    )
    def parse_obj(
        cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
    ) -> _TSQLModel:
        if not IS_PYDANTIC_V2:
            obj = cls._enforce_dict_if_root(obj)  # type: ignore[attr-defined] # noqa
        return cls.model_validate(obj, update=update)

    # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
    @deprecated(
        """
        🚨 You should not access `obj._calculate_keys()` directly.

        It is only useful for Pydantic v1.X, you should probably upgrade to
        Pydantic v2.X.
        """,
        category=None,
    )
    def _calculate_keys(
        self,
        include: Optional[Mapping[Union[int, str], Any]],
        exclude: Optional[Mapping[Union[int, str], Any]],
        exclude_unset: bool,
        update: Optional[Dict[str, Any]] = None,
    ) -> Optional[AbstractSet[str]]:
        return _calculate_keys(
            self,
            include=include,
            exclude=exclude,
            exclude_unset=exclude_unset,
            update=update,
        )

    def sqlmodel_update(
        self: _TSQLModel,
        obj: Union[Dict[str, Any], BaseModel],
        *,
        update: Union[Dict[str, Any], None] = None,
    ) -> _TSQLModel:
        """Copy values from ``obj`` (plus ``update``) onto this instance.

        For a dict, every key of ``{**obj, **update}`` that is a field of this
        model is set. For a Pydantic model, every field of ``obj`` is set
        (taking the value from ``update`` when present), then the remaining
        ``update`` keys that are fields of this model are set too.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If ``obj`` is neither a dict nor a Pydantic model.
        """
        use_update = (update or {}).copy()
        self_fields = get_model_fields(self)
        if isinstance(obj, dict):
            for key, value in {**obj, **use_update}.items():
                if key in self_fields:
                    setattr(self, key, value)
        elif isinstance(obj, BaseModel):
            for key in get_model_fields(obj):
                if key in use_update:
                    value = use_update.pop(key)
                else:
                    value = getattr(obj, key)
                setattr(self, key, value)
            # Iterate over a snapshot without mutating the dict: popping while
            # iterating raised "dictionary changed size during iteration".
            for remaining_key, value in list(use_update.items()):
                if remaining_key in self_fields:
                    setattr(self, remaining_key, value)
        else:
            raise ValueError(
                "Can't use sqlmodel_update() with something that "
                f"is not a dict or SQLModel or Pydantic model: {obj}"
            )
        return self
|