support sqlalchemy polymorphic
This commit is contained in:
parent
e86b5fcc84
commit
6d93a46fe0
@ -21,6 +21,7 @@ from typing import (
|
|||||||
from pydantic import VERSION as P_VERSION
|
from pydantic import VERSION as P_VERSION
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
|
from sqlalchemy import inspect
|
||||||
from typing_extensions import Annotated, get_args, get_origin
|
from typing_extensions import Annotated, get_args, get_origin
|
||||||
|
|
||||||
# Reassign variable to make it reexported for mypy
|
# Reassign variable to make it reexported for mypy
|
||||||
@ -290,6 +291,19 @@ if IS_PYDANTIC_V2:
|
|||||||
if value is not Undefined:
|
if value is not Undefined:
|
||||||
setattr(self_instance, key, value)
|
setattr(self_instance, key, value)
|
||||||
# End SQLModel override
|
# End SQLModel override
|
||||||
|
# Override polymorphic_on default value
|
||||||
|
mapper = inspect(cls)
|
||||||
|
polymorphic_on = mapper.polymorphic_on
|
||||||
|
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
|
||||||
|
field_info = cls.model_fields.get(polymorphic_property.key)
|
||||||
|
if field_info:
|
||||||
|
v = values.get(polymorphic_property.key)
|
||||||
|
# if model is inherited or polymorphic_on is not explicitly set
|
||||||
|
# set the polymorphic_on by default
|
||||||
|
if mapper.inherits or v is None:
|
||||||
|
setattr(
|
||||||
|
self_instance, polymorphic_property.key, mapper.polymorphic_identity
|
||||||
|
)
|
||||||
return self_instance
|
return self_instance
|
||||||
|
|
||||||
def sqlmodel_validate(
|
def sqlmodel_validate(
|
||||||
|
@ -41,9 +41,10 @@ from sqlalchemy import (
|
|||||||
)
|
)
|
||||||
from sqlalchemy import Enum as sa_Enum
|
from sqlalchemy import Enum as sa_Enum
|
||||||
from sqlalchemy.orm import (
|
from sqlalchemy.orm import (
|
||||||
|
InstrumentedAttribute,
|
||||||
Mapped,
|
Mapped,
|
||||||
|
MappedColumn,
|
||||||
RelationshipProperty,
|
RelationshipProperty,
|
||||||
declared_attr,
|
|
||||||
registry,
|
registry,
|
||||||
relationship,
|
relationship,
|
||||||
)
|
)
|
||||||
@ -544,6 +545,15 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
**pydantic_annotations,
|
**pydantic_annotations,
|
||||||
**new_cls.__annotations__,
|
**new_cls.__annotations__,
|
||||||
}
|
}
|
||||||
|
# pydantic will set class attribute value inherited from parent as field
|
||||||
|
# default value, reset it back
|
||||||
|
base_fields = {}
|
||||||
|
for base in bases[::-1]:
|
||||||
|
if issubclass(base, BaseModel):
|
||||||
|
base_fields.update(base.model_fields)
|
||||||
|
for k, v in new_cls.model_fields.items():
|
||||||
|
if isinstance(v.default, InstrumentedAttribute):
|
||||||
|
new_cls.model_fields[k] = base_fields.get(k)
|
||||||
|
|
||||||
def get_config(name: str) -> Any:
|
def get_config(name: str) -> Any:
|
||||||
config_class_value = get_config_value(
|
config_class_value = get_config_value(
|
||||||
@ -558,9 +568,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
|
|
||||||
config_table = get_config("table")
|
config_table = get_config("table")
|
||||||
if config_table is True:
|
if config_table is True:
|
||||||
|
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
|
||||||
|
new_cls.__tablename__ = new_cls.__name__.lower()
|
||||||
# If it was passed by kwargs, ensure it's also set in config
|
# If it was passed by kwargs, ensure it's also set in config
|
||||||
set_config_value(model=new_cls, parameter="table", value=config_table)
|
set_config_value(model=new_cls, parameter="table", value=config_table)
|
||||||
for k, v in get_model_fields(new_cls).items():
|
for k, v in get_model_fields(new_cls).items():
|
||||||
|
original_v = getattr(new_cls, k, None)
|
||||||
|
if (
|
||||||
|
isinstance(original_v, InstrumentedAttribute)
|
||||||
|
and k not in class_dict
|
||||||
|
):
|
||||||
|
# The attribute was already set by SQLAlchemy, don't override it
|
||||||
|
# Needed for polymorphic models, see #36
|
||||||
|
continue
|
||||||
col = get_column_from_field(v)
|
col = get_column_from_field(v)
|
||||||
setattr(new_cls, k, col)
|
setattr(new_cls, k, col)
|
||||||
# Set a config flag to tell FastAPI that this should be read with a field
|
# Set a config flag to tell FastAPI that this should be read with a field
|
||||||
@ -594,7 +614,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
# trying to create a new SQLAlchemy, for a new table, with the same name, that
|
# trying to create a new SQLAlchemy, for a new table, with the same name, that
|
||||||
# triggers an error
|
# triggers an error
|
||||||
base_is_table = any(is_table_model_class(base) for base in bases)
|
base_is_table = any(is_table_model_class(base) for base in bases)
|
||||||
if is_table_model_class(cls) and not base_is_table:
|
polymorphic_identity = dict_.get("__mapper_args__", {}).get(
|
||||||
|
"polymorphic_identity"
|
||||||
|
)
|
||||||
|
has_polymorphic = polymorphic_identity is not None
|
||||||
|
|
||||||
|
# allow polymorphic models inherit from table models
|
||||||
|
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
|
||||||
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
|
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
|
||||||
if rel_info.sa_relationship:
|
if rel_info.sa_relationship:
|
||||||
# There's a SQLAlchemy relationship declared, that takes precedence
|
# There's a SQLAlchemy relationship declared, that takes precedence
|
||||||
@ -641,6 +667,16 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
|
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
|
||||||
# Tag: 1.4.36
|
# Tag: 1.4.36
|
||||||
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
||||||
|
# # patch sqlmodel field's default value to polymorphic_identity
|
||||||
|
# if has_polymorphic:
|
||||||
|
# mapper = inspect(cls)
|
||||||
|
# polymorphic_on = mapper.polymorphic_on
|
||||||
|
# polymorphic_property = mapper.get_property_by_column(polymorphic_on)
|
||||||
|
# field = cls.model_fields.get(polymorphic_property.key)
|
||||||
|
# def get__polymorphic_identity__(kw):
|
||||||
|
# return polymorphic_identity
|
||||||
|
# if field:
|
||||||
|
# field.default_factory = get__polymorphic_identity__
|
||||||
else:
|
else:
|
||||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||||
|
|
||||||
@ -708,7 +744,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
|
|||||||
else:
|
else:
|
||||||
field_info = field.field_info
|
field_info = field.field_info
|
||||||
sa_column = getattr(field_info, "sa_column", Undefined)
|
sa_column = getattr(field_info, "sa_column", Undefined)
|
||||||
if isinstance(sa_column, Column):
|
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
|
||||||
return sa_column
|
return sa_column
|
||||||
sa_type = get_sqlalchemy_type(field)
|
sa_type = get_sqlalchemy_type(field)
|
||||||
primary_key = getattr(field_info, "primary_key", Undefined)
|
primary_key = getattr(field_info, "primary_key", Undefined)
|
||||||
@ -772,7 +808,6 @@ _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
|
|||||||
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
|
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
|
||||||
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
|
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
|
||||||
__slots__ = ("__weakref__",)
|
__slots__ = ("__weakref__",)
|
||||||
__tablename__: ClassVar[Union[str, Callable[..., str]]]
|
|
||||||
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
|
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
|
||||||
__name__: ClassVar[str]
|
__name__: ClassVar[str]
|
||||||
metadata: ClassVar[MetaData]
|
metadata: ClassVar[MetaData]
|
||||||
@ -836,10 +871,6 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
|||||||
if not (isinstance(k, str) and k.startswith("_sa_"))
|
if not (isinstance(k, str) and k.startswith("_sa_"))
|
||||||
]
|
]
|
||||||
|
|
||||||
@declared_attr # type: ignore
|
|
||||||
def __tablename__(cls) -> str:
|
|
||||||
return cls.__name__.lower()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def model_validate(
|
def model_validate(
|
||||||
cls: Type[_TSQLModel],
|
cls: Type[_TSQLModel],
|
||||||
|
127
tests/test_polymorphic_model.py
Normal file
127
tests/test_polymorphic_model.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey
|
||||||
|
from sqlalchemy.orm import mapped_column
|
||||||
|
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||||
|
|
||||||
|
|
||||||
|
def test_polymorphic_joined_table(clear_sqlmodel) -> None:
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
__tablename__ = "hero"
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
hero_type: str = Field(default="hero")
|
||||||
|
|
||||||
|
__mapper_args__ = {
|
||||||
|
"polymorphic_on": "hero_type",
|
||||||
|
"polymorphic_identity": "hero",
|
||||||
|
}
|
||||||
|
|
||||||
|
class DarkHero(Hero):
|
||||||
|
__tablename__ = "dark_hero"
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True),
|
||||||
|
)
|
||||||
|
dark_power: str = Field(
|
||||||
|
default="dark",
|
||||||
|
sa_column=mapped_column(
|
||||||
|
nullable=False, use_existing_column=True, default="dark"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
__mapper_args__ = {
|
||||||
|
"polymorphic_identity": "dark",
|
||||||
|
}
|
||||||
|
|
||||||
|
engine = create_engine("sqlite:///:memory:", echo=True)
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
with Session(engine) as db:
|
||||||
|
hero = Hero()
|
||||||
|
db.add(hero)
|
||||||
|
dark_hero = DarkHero()
|
||||||
|
db.add(dark_hero)
|
||||||
|
db.commit()
|
||||||
|
statement = select(DarkHero)
|
||||||
|
result = db.exec(statement).all()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert isinstance(result[0].dark_power, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None:
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
__tablename__ = "hero"
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
hero_type: str = Field(default="hero")
|
||||||
|
|
||||||
|
__mapper_args__ = {
|
||||||
|
"polymorphic_on": "hero_type",
|
||||||
|
"polymorphic_identity": "hero",
|
||||||
|
}
|
||||||
|
|
||||||
|
class DarkHero(Hero):
|
||||||
|
__tablename__ = "dark_hero"
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
primary_key=True,
|
||||||
|
foreign_key="hero.id",
|
||||||
|
)
|
||||||
|
dark_power: str = Field(
|
||||||
|
default="dark",
|
||||||
|
sa_column=mapped_column(
|
||||||
|
nullable=False, use_existing_column=True, default="dark"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
__mapper_args__ = {
|
||||||
|
"polymorphic_identity": "dark",
|
||||||
|
}
|
||||||
|
|
||||||
|
engine = create_engine("sqlite:///:memory:", echo=True)
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
with Session(engine) as db:
|
||||||
|
hero = Hero()
|
||||||
|
db.add(hero)
|
||||||
|
dark_hero = DarkHero()
|
||||||
|
db.add(dark_hero)
|
||||||
|
db.commit()
|
||||||
|
statement = select(DarkHero)
|
||||||
|
result = db.exec(statement).all()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert isinstance(result[0].dark_power, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_polymorphic_single_table(clear_sqlmodel) -> None:
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
__tablename__ = "hero"
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
hero_type: str = Field(default="hero")
|
||||||
|
|
||||||
|
__mapper_args__ = {
|
||||||
|
"polymorphic_on": "hero_type",
|
||||||
|
"polymorphic_identity": "hero",
|
||||||
|
}
|
||||||
|
|
||||||
|
class DarkHero(Hero):
|
||||||
|
dark_power: str = Field(
|
||||||
|
default="dark",
|
||||||
|
sa_column=mapped_column(
|
||||||
|
nullable=False, use_existing_column=True, default="dark"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
__mapper_args__ = {
|
||||||
|
"polymorphic_identity": "dark",
|
||||||
|
}
|
||||||
|
|
||||||
|
engine = create_engine("sqlite:///:memory:", echo=True)
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
with Session(engine) as db:
|
||||||
|
hero = Hero()
|
||||||
|
db.add(hero)
|
||||||
|
dark_hero = DarkHero()
|
||||||
|
db.add(dark_hero)
|
||||||
|
db.commit()
|
||||||
|
statement = select(DarkHero)
|
||||||
|
result = db.exec(statement).all()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert isinstance(result[0].dark_power, str)
|
Loading…
x
Reference in New Issue
Block a user