From e6ad74d50a943f006cb4876193d55edba043a727 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 26 Nov 2024 11:17:55 +0800 Subject: [PATCH] fix lint --- sqlmodel/_compat.py | 30 ++++++++++++++++-------------- sqlmodel/main.py | 9 +++++---- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 740e27a..10742d8 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -22,6 +22,7 @@ from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo from sqlalchemy import inspect +from sqlalchemy.orm import Mapper from typing_extensions import Annotated, get_args, get_origin # Reassign variable to make it reexported for mypy @@ -293,20 +294,21 @@ if IS_PYDANTIC_V2: # End SQLModel override # Override polymorphic_on default value mapper = inspect(cls) - polymorphic_on = mapper.polymorphic_on - if polymorphic_on is not None: - polymorphic_property = mapper.get_property_by_column(polymorphic_on) - field_info = cls.model_fields.get(polymorphic_property.key) - if field_info: - v = values.get(polymorphic_property.key) - # if model is inherited or polymorphic_on is not explicitly set - # set the polymorphic_on by default - if mapper.inherits or v is None: - setattr( - self_instance, - polymorphic_property.key, - mapper.polymorphic_identity, - ) + if isinstance(mapper, Mapper): + polymorphic_on = mapper.polymorphic_on + if polymorphic_on is not None: + polymorphic_property = mapper.get_property_by_column(polymorphic_on) + field_info = cls.model_fields.get(polymorphic_property.key) + if field_info: + v = values.get(polymorphic_property.key) + # if model is inherited or polymorphic_on is not explicitly set + # set the polymorphic_on by default + if mapper.inherits or v is None: + setattr( + self_instance, + polymorphic_property.key, + mapper.polymorphic_identity, + ) return self_instance def sqlmodel_validate( diff --git a/sqlmodel/main.py b/sqlmodel/main.py index f85dfc4..923079e 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -551,9 +551,10 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): for base in bases[::-1]: if issubclass(base, BaseModel): base_fields.update(base.model_fields) - for k, v in new_cls.model_fields.items(): + fields = get_model_fields(new_cls) + for k, v in fields.items(): if isinstance(v.default, InstrumentedAttribute): - new_cls.model_fields[k] = base_fields.get(k) + fields[k] = base_fields.get(k, FieldInfo()) def get_config(name: str) -> Any: config_class_value = get_config_value( @@ -572,7 +573,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): # or if __tablename__ is in __annotations__. Only set __tablename__ if it's # a table model if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"): - new_cls.__tablename__ = new_cls.__name__.lower() + setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010 # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): @@ -731,7 +732,7 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field(field: Any) -> Column | MappedColumn: # type: ignore if IS_PYDANTIC_V2: field_info = field else: