This commit is contained in:
John Lyu 2024-11-26 11:17:55 +08:00
parent 48f2a88752
commit e6ad74d50a
2 changed files with 21 additions and 18 deletions

View File

@ -22,6 +22,7 @@ from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from sqlalchemy import inspect
from sqlalchemy.orm import Mapper
from typing_extensions import Annotated, get_args, get_origin
# Reassign variable to make it reexported for mypy
@ -293,20 +294,21 @@ if IS_PYDANTIC_V2:
# End SQLModel override
# Override polymorphic_on default value
mapper = inspect(cls)
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = cls.model_fields.get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance,
polymorphic_property.key,
mapper.polymorphic_identity,
)
if isinstance(mapper, Mapper):
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = cls.model_fields.get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance,
polymorphic_property.key,
mapper.polymorphic_identity,
)
return self_instance
def sqlmodel_validate(

View File

@ -551,9 +551,10 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
for base in bases[::-1]:
if issubclass(base, BaseModel):
base_fields.update(base.model_fields)
for k, v in new_cls.model_fields.items():
fields = get_model_fields(new_cls)
for k, v in fields.items():
if isinstance(v.default, InstrumentedAttribute):
new_cls.model_fields[k] = base_fields.get(k)
fields[k] = base_fields.get(k, FieldInfo())
def get_config(name: str) -> Any:
config_class_value = get_config_value(
@ -572,7 +573,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's
# a table model
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
new_cls.__tablename__ = new_cls.__name__.lower()
setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
# If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items():
@ -731,7 +732,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type")
def get_column_from_field(field: Any) -> Column: # type: ignore
def get_column_from_field(field: Any) -> Column | MappedColumn: # type: ignore
if IS_PYDANTIC_V2:
field_info = field
else: