improve code structure
This commit is contained in:
@@ -66,6 +66,29 @@ def _is_union_type(t: Any) -> bool:
|
||||
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)
|
||||
|
||||
|
||||
def set_polymorphic_default_value(self_instance, values):
|
||||
"""By defalut, when init a model, pydantic will set the polymorphic_on
|
||||
value to field default value. But when inherit a model, the polymorphic_on
|
||||
should be set to polymorphic_identity value by default."""
|
||||
cls = type(self_instance)
|
||||
mapper = inspect(cls)
|
||||
if isinstance(mapper, Mapper):
|
||||
polymorphic_on = mapper.polymorphic_on
|
||||
if polymorphic_on is not None:
|
||||
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
|
||||
field_info = get_model_fields(cls).get(polymorphic_property.key)
|
||||
if field_info:
|
||||
v = values.get(polymorphic_property.key)
|
||||
# if model is inherited or polymorphic_on is not explicitly set
|
||||
# set the polymorphic_on by default
|
||||
if mapper.inherits or v is None:
|
||||
setattr(
|
||||
self_instance,
|
||||
polymorphic_property.key,
|
||||
mapper.polymorphic_identity,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def partial_init() -> Generator[None, None, None]:
|
||||
token = finish_init.set(False)
|
||||
@@ -293,22 +316,7 @@ if IS_PYDANTIC_V2:
|
||||
setattr(self_instance, key, value)
|
||||
# End SQLModel override
|
||||
# Override polymorphic_on default value
|
||||
mapper = inspect(cls)
|
||||
if isinstance(mapper, Mapper):
|
||||
polymorphic_on = mapper.polymorphic_on
|
||||
if polymorphic_on is not None:
|
||||
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
|
||||
field_info = cls.model_fields.get(polymorphic_property.key)
|
||||
if field_info:
|
||||
v = values.get(polymorphic_property.key)
|
||||
# if model is inherited or polymorphic_on is not explicitly set
|
||||
# set the polymorphic_on by default
|
||||
if mapper.inherits or v is None:
|
||||
setattr(
|
||||
self_instance,
|
||||
polymorphic_property.key,
|
||||
mapper.polymorphic_identity,
|
||||
)
|
||||
set_polymorphic_default_value(self_instance, values)
|
||||
return self_instance
|
||||
|
||||
def sqlmodel_validate(
|
||||
@@ -592,3 +600,5 @@ else:
|
||||
for key in non_pydantic_keys:
|
||||
if key in self.__sqlmodel_relationships__:
|
||||
setattr(self, key, data[key])
|
||||
# Override polymorphic_on default value
|
||||
set_polymorphic_default_value(self, values)
|
||||
|
||||
Reference in New Issue
Block a user