diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 4e80cdc..6b7d53b 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -21,6 +21,7 @@ from typing import ( from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo +from sqlalchemy import inspect from typing_extensions import Annotated, get_args, get_origin # Reassign variable to make it reexported for mypy @@ -290,6 +291,19 @@ if IS_PYDANTIC_V2: if value is not Undefined: setattr(self_instance, key, value) # End SQLModel override + # Override polymorphic_on default value + mapper = inspect(cls) + polymorphic_on = mapper.polymorphic_on + polymorphic_property = mapper.get_property_by_column(polymorphic_on) + field_info = cls.model_fields.get(polymorphic_property.key) + if field_info: + v = values.get(polymorphic_property.key) + # if model is inherited or polymorphic_on is not explicitly set + # set the polymorphic_on by default + if mapper.inherits or v is None: + setattr( + self_instance, polymorphic_property.key, mapper.polymorphic_identity + ) return self_instance def sqlmodel_validate( diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3532e81..fcba557 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -41,9 +41,10 @@ from sqlalchemy import ( ) from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import ( + InstrumentedAttribute, Mapped, + MappedColumn, RelationshipProperty, - declared_attr, registry, relationship, ) @@ -544,6 +545,15 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): **pydantic_annotations, **new_cls.__annotations__, } + # pydantic will set class attribute value inherited from parent as field + # default value, reset it back + base_fields = {} + for base in bases[::-1]: + if issubclass(base, BaseModel): + base_fields.update(base.model_fields) + for k, v in new_cls.model_fields.items(): + if isinstance(v.default, InstrumentedAttribute): + new_cls.model_fields[k] = base_fields.get(k) def get_config(name: str) -> Any: config_class_value = get_config_value( @@ -558,9 +568,19 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): config_table = get_config("table") if config_table is True: + if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"): + new_cls.__tablename__ = new_cls.__name__.lower() # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): + original_v = getattr(new_cls, k, None) + if ( + isinstance(original_v, InstrumentedAttribute) + and k not in class_dict + ): + # The attribute was already set by SQLAlchemy, don't override it + # Needed for polymorphic models, see #36 + continue col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field @@ -594,7 +614,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): # trying to create a new SQLAlchemy, for a new table, with the same name, that # triggers an error base_is_table = any(is_table_model_class(base) for base in bases) - if is_table_model_class(cls) and not base_is_table: + polymorphic_identity = dict_.get("__mapper_args__", {}).get( + "polymorphic_identity" + ) + has_polymorphic = polymorphic_identity is not None + + # allow polymorphic models inherit from table models + if is_table_model_class(cls) and (not base_is_table or has_polymorphic): for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence @@ -641,6 +667,16 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77 # Tag: 1.4.36 DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw) + # # patch sqlmodel field's default value to polymorphic_identity + # if has_polymorphic: + # mapper = inspect(cls) + # polymorphic_on = mapper.polymorphic_on + # polymorphic_property = mapper.get_property_by_column(polymorphic_on) + # field = cls.model_fields.get(polymorphic_property.key) + # def get__polymorphic_identity__(kw): + # return polymorphic_identity + # if field: + # field.default_factory = get__polymorphic_identity__ else: ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) @@ -708,7 +744,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore else: field_info = field.field_info sa_column = getattr(field_info, "sa_column", Undefined) - if isinstance(sa_column, Column): + if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn): return sa_column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field_info, "primary_key", Undefined) @@ -772,7 +808,6 @@ _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) - __tablename__: ClassVar[Union[str, Callable[..., str]]] __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] __name__: ClassVar[str] metadata: ClassVar[MetaData] @@ -836,10 +871,6 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry if not (isinstance(k, str) and k.startswith("_sa_")) ] - @declared_attr # type: ignore - def __tablename__(cls) -> str: - return cls.__name__.lower() - @classmethod def model_validate( cls: Type[_TSQLModel], diff --git a/tests/test_polymorphic_model.py b/tests/test_polymorphic_model.py new file mode 100644 index 0000000..c9c8330 --- /dev/null +++ b/tests/test_polymorphic_model.py @@ -0,0 +1,127 @@ +from typing import Optional + +from sqlalchemy import ForeignKey +from sqlalchemy.orm import mapped_column +from sqlmodel import Field, Session, SQLModel, create_engine, select + + +def test_polymorphic_joined_table(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "hero", + } + + class DarkHero(Hero): + __tablename__ = "dark_hero" + id: Optional[int] = Field( + default=None, + sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True), + ) + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) + + +def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "hero", + } + + class DarkHero(Hero): + __tablename__ = "dark_hero" + id: Optional[int] = Field( + default=None, + primary_key=True, + foreign_key="hero.id", + ) + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) + + +def test_polymorphic_single_table(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "hero", + } + + class DarkHero(Hero): + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str)