support sqlalchemy polymorphic

This commit is contained in:
John Lyu
2024-11-26 10:43:13 +08:00
parent e86b5fcc84
commit 6d93a46fe0
3 changed files with 180 additions and 8 deletions

View File

@@ -21,6 +21,7 @@ from typing import (
from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from sqlalchemy import inspect
from typing_extensions import Annotated, get_args, get_origin
# Reassign variable to make it reexported for mypy
@@ -290,6 +291,19 @@ if IS_PYDANTIC_V2:
if value is not Undefined:
setattr(self_instance, key, value)
# End SQLModel override
# Override polymorphic_on default value
mapper = inspect(cls)
polymorphic_on = mapper.polymorphic_on
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = cls.model_fields.get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance, polymorphic_property.key, mapper.polymorphic_identity
)
return self_instance
def sqlmodel_validate(

View File

@@ -41,9 +41,10 @@ from sqlalchemy import (
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import (
InstrumentedAttribute,
Mapped,
MappedColumn,
RelationshipProperty,
declared_attr,
registry,
relationship,
)
@@ -544,6 +545,15 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
**pydantic_annotations,
**new_cls.__annotations__,
}
# pydantic will set class attribute value inherited from parent as field
# default value, reset it back
base_fields = {}
for base in bases[::-1]:
if issubclass(base, BaseModel):
base_fields.update(base.model_fields)
for k, v in new_cls.model_fields.items():
if isinstance(v.default, InstrumentedAttribute):
new_cls.model_fields[k] = base_fields.get(k)
def get_config(name: str) -> Any:
config_class_value = get_config_value(
@@ -558,9 +568,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
config_table = get_config("table")
if config_table is True:
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
new_cls.__tablename__ = new_cls.__name__.lower()
# If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items():
original_v = getattr(new_cls, k, None)
if (
isinstance(original_v, InstrumentedAttribute)
and k not in class_dict
):
# The attribute was already set by SQLAlchemy, don't override it
# Needed for polymorphic models, see #36
continue
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
@@ -594,7 +614,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# trying to create a new SQLAlchemy, for a new table, with the same name, that
# triggers an error
base_is_table = any(is_table_model_class(base) for base in bases)
if is_table_model_class(cls) and not base_is_table:
polymorphic_identity = dict_.get("__mapper_args__", {}).get(
"polymorphic_identity"
)
has_polymorphic = polymorphic_identity is not None
# allow polymorphic models inherit from table models
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
@@ -641,6 +667,16 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
# Tag: 1.4.36
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
# # patch sqlmodel field's default value to polymorphic_identity
# if has_polymorphic:
# mapper = inspect(cls)
# polymorphic_on = mapper.polymorphic_on
# polymorphic_property = mapper.get_property_by_column(polymorphic_on)
# field = cls.model_fields.get(polymorphic_property.key)
# def get__polymorphic_identity__(kw):
# return polymorphic_identity
# if field:
# field.default_factory = get__polymorphic_identity__
else:
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
@@ -708,7 +744,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
else:
field_info = field.field_info
sa_column = getattr(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field_info, "primary_key", Undefined)
@@ -772,7 +808,6 @@ _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
@@ -836,10 +871,6 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
if not (isinstance(k, str) and k.startswith("_sa_"))
]
@declared_attr # type: ignore
def __tablename__(cls) -> str:
return cls.__name__.lower()
@classmethod
def model_validate(
cls: Type[_TSQLModel],