support sqlalchemy polymorphic
This commit is contained in:
parent
e86b5fcc84
commit
6d93a46fe0
@ -21,6 +21,7 @@ from typing import (
|
||||
from pydantic import VERSION as P_VERSION
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from sqlalchemy import inspect
|
||||
from typing_extensions import Annotated, get_args, get_origin
|
||||
|
||||
# Reassign variable to make it reexported for mypy
|
||||
@ -290,6 +291,19 @@ if IS_PYDANTIC_V2:
|
||||
if value is not Undefined:
|
||||
setattr(self_instance, key, value)
|
||||
# End SQLModel override
|
||||
# Override polymorphic_on default value
|
||||
mapper = inspect(cls)
|
||||
polymorphic_on = mapper.polymorphic_on
|
||||
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
|
||||
field_info = cls.model_fields.get(polymorphic_property.key)
|
||||
if field_info:
|
||||
v = values.get(polymorphic_property.key)
|
||||
# if model is inherited or polymorphic_on is not explicitly set
|
||||
# set the polymorphic_on by default
|
||||
if mapper.inherits or v is None:
|
||||
setattr(
|
||||
self_instance, polymorphic_property.key, mapper.polymorphic_identity
|
||||
)
|
||||
return self_instance
|
||||
|
||||
def sqlmodel_validate(
|
||||
|
@ -41,9 +41,10 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy import Enum as sa_Enum
|
||||
from sqlalchemy.orm import (
|
||||
InstrumentedAttribute,
|
||||
Mapped,
|
||||
MappedColumn,
|
||||
RelationshipProperty,
|
||||
declared_attr,
|
||||
registry,
|
||||
relationship,
|
||||
)
|
||||
@ -544,6 +545,15 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
**pydantic_annotations,
|
||||
**new_cls.__annotations__,
|
||||
}
|
||||
# pydantic will set class attribute value inherited from parent as field
|
||||
# default value, reset it back
|
||||
base_fields = {}
|
||||
for base in bases[::-1]:
|
||||
if issubclass(base, BaseModel):
|
||||
base_fields.update(base.model_fields)
|
||||
for k, v in new_cls.model_fields.items():
|
||||
if isinstance(v.default, InstrumentedAttribute):
|
||||
new_cls.model_fields[k] = base_fields.get(k)
|
||||
|
||||
def get_config(name: str) -> Any:
|
||||
config_class_value = get_config_value(
|
||||
@ -558,9 +568,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
|
||||
config_table = get_config("table")
|
||||
if config_table is True:
|
||||
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
|
||||
new_cls.__tablename__ = new_cls.__name__.lower()
|
||||
# If it was passed by kwargs, ensure it's also set in config
|
||||
set_config_value(model=new_cls, parameter="table", value=config_table)
|
||||
for k, v in get_model_fields(new_cls).items():
|
||||
original_v = getattr(new_cls, k, None)
|
||||
if (
|
||||
isinstance(original_v, InstrumentedAttribute)
|
||||
and k not in class_dict
|
||||
):
|
||||
# The attribute was already set by SQLAlchemy, don't override it
|
||||
# Needed for polymorphic models, see #36
|
||||
continue
|
||||
col = get_column_from_field(v)
|
||||
setattr(new_cls, k, col)
|
||||
# Set a config flag to tell FastAPI that this should be read with a field
|
||||
@ -594,7 +614,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
# trying to create a new SQLAlchemy, for a new table, with the same name, that
|
||||
# triggers an error
|
||||
base_is_table = any(is_table_model_class(base) for base in bases)
|
||||
if is_table_model_class(cls) and not base_is_table:
|
||||
polymorphic_identity = dict_.get("__mapper_args__", {}).get(
|
||||
"polymorphic_identity"
|
||||
)
|
||||
has_polymorphic = polymorphic_identity is not None
|
||||
|
||||
# allow polymorphic models inherit from table models
|
||||
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
|
||||
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
|
||||
if rel_info.sa_relationship:
|
||||
# There's a SQLAlchemy relationship declared, that takes precedence
|
||||
@ -641,6 +667,16 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
|
||||
# Tag: 1.4.36
|
||||
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
||||
# # patch sqlmodel field's default value to polymorphic_identity
|
||||
# if has_polymorphic:
|
||||
# mapper = inspect(cls)
|
||||
# polymorphic_on = mapper.polymorphic_on
|
||||
# polymorphic_property = mapper.get_property_by_column(polymorphic_on)
|
||||
# field = cls.model_fields.get(polymorphic_property.key)
|
||||
# def get__polymorphic_identity__(kw):
|
||||
# return polymorphic_identity
|
||||
# if field:
|
||||
# field.default_factory = get__polymorphic_identity__
|
||||
else:
|
||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
|
||||
@ -708,7 +744,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
|
||||
else:
|
||||
field_info = field.field_info
|
||||
sa_column = getattr(field_info, "sa_column", Undefined)
|
||||
if isinstance(sa_column, Column):
|
||||
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
|
||||
return sa_column
|
||||
sa_type = get_sqlalchemy_type(field)
|
||||
primary_key = getattr(field_info, "primary_key", Undefined)
|
||||
@ -772,7 +808,6 @@ _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
|
||||
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
|
||||
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
|
||||
__slots__ = ("__weakref__",)
|
||||
__tablename__: ClassVar[Union[str, Callable[..., str]]]
|
||||
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
|
||||
__name__: ClassVar[str]
|
||||
metadata: ClassVar[MetaData]
|
||||
@ -836,10 +871,6 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
if not (isinstance(k, str) and k.startswith("_sa_"))
|
||||
]
|
||||
|
||||
@declared_attr # type: ignore
|
||||
def __tablename__(cls) -> str:
|
||||
return cls.__name__.lower()
|
||||
|
||||
@classmethod
|
||||
def model_validate(
|
||||
cls: Type[_TSQLModel],
|
||||
|
127
tests/test_polymorphic_model.py
Normal file
127
tests/test_polymorphic_model.py
Normal file
@ -0,0 +1,127 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
|
||||
def test_polymorphic_joined_table(clear_sqlmodel) -> None:
|
||||
class Hero(SQLModel, table=True):
|
||||
__tablename__ = "hero"
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
hero_type: str = Field(default="hero")
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_on": "hero_type",
|
||||
"polymorphic_identity": "hero",
|
||||
}
|
||||
|
||||
class DarkHero(Hero):
|
||||
__tablename__ = "dark_hero"
|
||||
id: Optional[int] = Field(
|
||||
default=None,
|
||||
sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True),
|
||||
)
|
||||
dark_power: str = Field(
|
||||
default="dark",
|
||||
sa_column=mapped_column(
|
||||
nullable=False, use_existing_column=True, default="dark"
|
||||
),
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "dark",
|
||||
}
|
||||
|
||||
engine = create_engine("sqlite:///:memory:", echo=True)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as db:
|
||||
hero = Hero()
|
||||
db.add(hero)
|
||||
dark_hero = DarkHero()
|
||||
db.add(dark_hero)
|
||||
db.commit()
|
||||
statement = select(DarkHero)
|
||||
result = db.exec(statement).all()
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].dark_power, str)
|
||||
|
||||
|
||||
def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None:
|
||||
class Hero(SQLModel, table=True):
|
||||
__tablename__ = "hero"
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
hero_type: str = Field(default="hero")
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_on": "hero_type",
|
||||
"polymorphic_identity": "hero",
|
||||
}
|
||||
|
||||
class DarkHero(Hero):
|
||||
__tablename__ = "dark_hero"
|
||||
id: Optional[int] = Field(
|
||||
default=None,
|
||||
primary_key=True,
|
||||
foreign_key="hero.id",
|
||||
)
|
||||
dark_power: str = Field(
|
||||
default="dark",
|
||||
sa_column=mapped_column(
|
||||
nullable=False, use_existing_column=True, default="dark"
|
||||
),
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "dark",
|
||||
}
|
||||
|
||||
engine = create_engine("sqlite:///:memory:", echo=True)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as db:
|
||||
hero = Hero()
|
||||
db.add(hero)
|
||||
dark_hero = DarkHero()
|
||||
db.add(dark_hero)
|
||||
db.commit()
|
||||
statement = select(DarkHero)
|
||||
result = db.exec(statement).all()
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].dark_power, str)
|
||||
|
||||
|
||||
def test_polymorphic_single_table(clear_sqlmodel) -> None:
|
||||
class Hero(SQLModel, table=True):
|
||||
__tablename__ = "hero"
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
hero_type: str = Field(default="hero")
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_on": "hero_type",
|
||||
"polymorphic_identity": "hero",
|
||||
}
|
||||
|
||||
class DarkHero(Hero):
|
||||
dark_power: str = Field(
|
||||
default="dark",
|
||||
sa_column=mapped_column(
|
||||
nullable=False, use_existing_column=True, default="dark"
|
||||
),
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "dark",
|
||||
}
|
||||
|
||||
engine = create_engine("sqlite:///:memory:", echo=True)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as db:
|
||||
hero = Hero()
|
||||
db.add(hero)
|
||||
dark_hero = DarkHero()
|
||||
db.add(dark_hero)
|
||||
db.commit()
|
||||
statement = select(DarkHero)
|
||||
result = db.exec(statement).all()
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].dark_power, str)
|
Loading…
x
Reference in New Issue
Block a user