This commit is contained in:
John Lyu 2024-11-26 11:17:55 +08:00
parent 48f2a88752
commit e6ad74d50a
2 changed files with 21 additions and 18 deletions

View File

@ -22,6 +22,7 @@ from pydantic import VERSION as P_VERSION
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.orm import Mapper
from typing_extensions import Annotated, get_args, get_origin from typing_extensions import Annotated, get_args, get_origin
# Reassign variable to make it reexported for mypy # Reassign variable to make it reexported for mypy
@ -293,20 +294,21 @@ if IS_PYDANTIC_V2:
# End SQLModel override # End SQLModel override
# Override polymorphic_on default value # Override polymorphic_on default value
mapper = inspect(cls) mapper = inspect(cls)
polymorphic_on = mapper.polymorphic_on if isinstance(mapper, Mapper):
if polymorphic_on is not None: polymorphic_on = mapper.polymorphic_on
polymorphic_property = mapper.get_property_by_column(polymorphic_on) if polymorphic_on is not None:
field_info = cls.model_fields.get(polymorphic_property.key) polymorphic_property = mapper.get_property_by_column(polymorphic_on)
if field_info: field_info = cls.model_fields.get(polymorphic_property.key)
v = values.get(polymorphic_property.key) if field_info:
# if model is inherited or polymorphic_on is not explicitly set v = values.get(polymorphic_property.key)
# set the polymorphic_on by default # if model is inherited or polymorphic_on is not explicitly set
if mapper.inherits or v is None: # set the polymorphic_on by default
setattr( if mapper.inherits or v is None:
self_instance, setattr(
polymorphic_property.key, self_instance,
mapper.polymorphic_identity, polymorphic_property.key,
) mapper.polymorphic_identity,
)
return self_instance return self_instance
def sqlmodel_validate( def sqlmodel_validate(

View File

@ -551,9 +551,10 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
for base in bases[::-1]: for base in bases[::-1]:
if issubclass(base, BaseModel): if issubclass(base, BaseModel):
base_fields.update(base.model_fields) base_fields.update(base.model_fields)
for k, v in new_cls.model_fields.items(): fields = get_model_fields(new_cls)
for k, v in fields.items():
if isinstance(v.default, InstrumentedAttribute): if isinstance(v.default, InstrumentedAttribute):
new_cls.model_fields[k] = base_fields.get(k) fields[k] = base_fields.get(k, FieldInfo())
def get_config(name: str) -> Any: def get_config(name: str) -> Any:
config_class_value = get_config_value( config_class_value = get_config_value(
@ -572,7 +573,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's # or if __tablename__ is in __annotations__. Only set __tablename__ if it's
# a table model # a table model
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"): if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
new_cls.__tablename__ = new_cls.__name__.lower() setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
# If it was passed by kwargs, ensure it's also set in config # If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table) set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items(): for k, v in get_model_fields(new_cls).items():
@ -731,7 +732,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type") raise ValueError(f"{type_} has no matching SQLAlchemy type")
def get_column_from_field(field: Any) -> Column: # type: ignore def get_column_from_field(field: Any) -> Column | MappedColumn: # type: ignore
if IS_PYDANTIC_V2: if IS_PYDANTIC_V2:
field_info = field field_info = field
else: else: