improve code structure

This commit is contained in:
John Lyu 2024-12-03 17:59:21 +08:00
parent a3044bbf68
commit 015601cd5b
2 changed files with 28 additions and 18 deletions

View File

@ -66,6 +66,29 @@ def _is_union_type(t: Any) -> bool:
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)
def set_polymorphic_default_value(self_instance, values):
"""By defalut, when init a model, pydantic will set the polymorphic_on
value to field default value. But when inherit a model, the polymorphic_on
should be set to polymorphic_identity value by default."""
cls = type(self_instance)
mapper = inspect(cls)
if isinstance(mapper, Mapper):
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = get_model_fields(cls).get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance,
polymorphic_property.key,
mapper.polymorphic_identity,
)
@contextmanager @contextmanager
def partial_init() -> Generator[None, None, None]: def partial_init() -> Generator[None, None, None]:
token = finish_init.set(False) token = finish_init.set(False)
@ -293,22 +316,7 @@ if IS_PYDANTIC_V2:
setattr(self_instance, key, value) setattr(self_instance, key, value)
# End SQLModel override # End SQLModel override
# Override polymorphic_on default value # Override polymorphic_on default value
mapper = inspect(cls) set_polymorphic_default_value(self_instance, values)
if isinstance(mapper, Mapper):
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = cls.model_fields.get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance,
polymorphic_property.key,
mapper.polymorphic_identity,
)
return self_instance return self_instance
def sqlmodel_validate( def sqlmodel_validate(
@ -592,3 +600,5 @@ else:
for key in non_pydantic_keys: for key in non_pydantic_keys:
if key in self.__sqlmodel_relationships__: if key in self.__sqlmodel_relationships__:
setattr(self, key, data[key]) setattr(self, key, data[key])
# Override polymorphic_on default value
set_polymorphic_default_value(self, values)

View File

@ -51,7 +51,7 @@ def test_polymorphic_joined_table(clear_sqlmodel) -> None:
@needs_pydanticv2 @needs_pydanticv2
def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None: def test_polymorphic_joined_table_with_sqlmodel_field(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True): class Hero(SQLModel, table=True):
__tablename__ = "hero" __tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
@ -123,7 +123,7 @@ def test_polymorphic_single_table(clear_sqlmodel) -> None:
with Session(engine) as db: with Session(engine) as db:
hero = Hero() hero = Hero()
db.add(hero) db.add(hero)
dark_hero = DarkHero() dark_hero = DarkHero(dark_power="pokey")
db.add(dark_hero) db.add(dark_hero)
db.commit() db.commit()
statement = select(DarkHero) statement = select(DarkHero)