support sqlalchemy polymorphic

This commit is contained in:
John Lyu 2024-11-26 10:43:13 +08:00
parent e86b5fcc84
commit 6d93a46fe0
3 changed files with 180 additions and 8 deletions

View File

@ -21,6 +21,7 @@ from typing import (
from pydantic import VERSION as P_VERSION from pydantic import VERSION as P_VERSION
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from sqlalchemy import inspect
from typing_extensions import Annotated, get_args, get_origin from typing_extensions import Annotated, get_args, get_origin
# Reassign variable to make it reexported for mypy # Reassign variable to make it reexported for mypy
@ -290,6 +291,19 @@ if IS_PYDANTIC_V2:
if value is not Undefined: if value is not Undefined:
setattr(self_instance, key, value) setattr(self_instance, key, value)
# End SQLModel override # End SQLModel override
# Override polymorphic_on default value
mapper = inspect(cls)
polymorphic_on = mapper.polymorphic_on
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = cls.model_fields.get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance, polymorphic_property.key, mapper.polymorphic_identity
)
return self_instance return self_instance
def sqlmodel_validate( def sqlmodel_validate(

View File

@ -41,9 +41,10 @@ from sqlalchemy import (
) )
from sqlalchemy import Enum as sa_Enum from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import ( from sqlalchemy.orm import (
InstrumentedAttribute,
Mapped, Mapped,
MappedColumn,
RelationshipProperty, RelationshipProperty,
declared_attr,
registry, registry,
relationship, relationship,
) )
@ -544,6 +545,15 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
**pydantic_annotations, **pydantic_annotations,
**new_cls.__annotations__, **new_cls.__annotations__,
} }
# pydantic will set class attribute value inherited from parent as field
# default value, reset it back
base_fields = {}
for base in bases[::-1]:
if issubclass(base, BaseModel):
base_fields.update(base.model_fields)
for k, v in new_cls.model_fields.items():
if isinstance(v.default, InstrumentedAttribute):
new_cls.model_fields[k] = base_fields.get(k)
def get_config(name: str) -> Any: def get_config(name: str) -> Any:
config_class_value = get_config_value( config_class_value = get_config_value(
@ -558,9 +568,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
config_table = get_config("table") config_table = get_config("table")
if config_table is True: if config_table is True:
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
new_cls.__tablename__ = new_cls.__name__.lower()
# If it was passed by kwargs, ensure it's also set in config # If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table) set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items(): for k, v in get_model_fields(new_cls).items():
original_v = getattr(new_cls, k, None)
if (
isinstance(original_v, InstrumentedAttribute)
and k not in class_dict
):
# The attribute was already set by SQLAlchemy, don't override it
# Needed for polymorphic models, see #36
continue
col = get_column_from_field(v) col = get_column_from_field(v)
setattr(new_cls, k, col) setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field # Set a config flag to tell FastAPI that this should be read with a field
@ -594,7 +614,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# trying to create a new SQLAlchemy, for a new table, with the same name, that # trying to create a new SQLAlchemy, for a new table, with the same name, that
# triggers an error # triggers an error
base_is_table = any(is_table_model_class(base) for base in bases) base_is_table = any(is_table_model_class(base) for base in bases)
if is_table_model_class(cls) and not base_is_table: polymorphic_identity = dict_.get("__mapper_args__", {}).get(
"polymorphic_identity"
)
has_polymorphic = polymorphic_identity is not None
# allow polymorphic models inherit from table models
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship: if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence # There's a SQLAlchemy relationship declared, that takes precedence
@ -641,6 +667,16 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77 # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
# Tag: 1.4.36 # Tag: 1.4.36
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw) DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
# # patch sqlmodel field's default value to polymorphic_identity
# if has_polymorphic:
# mapper = inspect(cls)
# polymorphic_on = mapper.polymorphic_on
# polymorphic_property = mapper.get_property_by_column(polymorphic_on)
# field = cls.model_fields.get(polymorphic_property.key)
# def get__polymorphic_identity__(kw):
# return polymorphic_identity
# if field:
# field.default_factory = get__polymorphic_identity__
else: else:
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
@ -708,7 +744,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
else: else:
field_info = field.field_info field_info = field.field_info
sa_column = getattr(field_info, "sa_column", Undefined) sa_column = getattr(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column): if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
return sa_column return sa_column
sa_type = get_sqlalchemy_type(field) sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field_info, "primary_key", Undefined) primary_key = getattr(field_info, "primary_key", Undefined)
@ -772,7 +808,6 @@ _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",) __slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
__name__: ClassVar[str] __name__: ClassVar[str]
metadata: ClassVar[MetaData] metadata: ClassVar[MetaData]
@ -836,10 +871,6 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
if not (isinstance(k, str) and k.startswith("_sa_")) if not (isinstance(k, str) and k.startswith("_sa_"))
] ]
@declared_attr # type: ignore
def __tablename__(cls) -> str:
return cls.__name__.lower()
@classmethod @classmethod
def model_validate( def model_validate(
cls: Type[_TSQLModel], cls: Type[_TSQLModel],

View File

@ -0,0 +1,127 @@
from typing import Optional
from sqlalchemy import ForeignKey
from sqlalchemy.orm import mapped_column
from sqlmodel import Field, Session, SQLModel, create_engine, select
def test_polymorphic_joined_table(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")
__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "hero",
}
class DarkHero(Hero):
__tablename__ = "dark_hero"
id: Optional[int] = Field(
default=None,
sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True),
)
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)
__mapper_args__ = {
"polymorphic_identity": "dark",
}
engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero()
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)
def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")
__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "hero",
}
class DarkHero(Hero):
__tablename__ = "dark_hero"
id: Optional[int] = Field(
default=None,
primary_key=True,
foreign_key="hero.id",
)
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)
__mapper_args__ = {
"polymorphic_identity": "dark",
}
engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero()
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)
def test_polymorphic_single_table(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")
__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "hero",
}
class DarkHero(Hero):
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)
__mapper_args__ = {
"polymorphic_identity": "dark",
}
engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero()
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)