This commit is contained in:
John Lyu 2024-11-26 11:17:55 +08:00
parent 48f2a88752
commit e6ad74d50a
2 changed files with 21 additions and 18 deletions

View File

@ -22,6 +22,7 @@ from pydantic import VERSION as P_VERSION
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.orm import Mapper
from typing_extensions import Annotated, get_args, get_origin from typing_extensions import Annotated, get_args, get_origin
# Reassign variable to make it reexported for mypy # Reassign variable to make it reexported for mypy
@ -293,6 +294,7 @@ if IS_PYDANTIC_V2:
# End SQLModel override # End SQLModel override
# Override polymorphic_on default value # Override polymorphic_on default value
mapper = inspect(cls) mapper = inspect(cls)
if isinstance(mapper, Mapper):
polymorphic_on = mapper.polymorphic_on polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None: if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on) polymorphic_property = mapper.get_property_by_column(polymorphic_on)

View File

@ -551,9 +551,10 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
for base in bases[::-1]: for base in bases[::-1]:
if issubclass(base, BaseModel): if issubclass(base, BaseModel):
base_fields.update(base.model_fields) base_fields.update(base.model_fields)
for k, v in new_cls.model_fields.items(): fields = get_model_fields(new_cls)
for k, v in fields.items():
if isinstance(v.default, InstrumentedAttribute): if isinstance(v.default, InstrumentedAttribute):
new_cls.model_fields[k] = base_fields.get(k) fields[k] = base_fields.get(k, FieldInfo())
def get_config(name: str) -> Any: def get_config(name: str) -> Any:
config_class_value = get_config_value( config_class_value = get_config_value(
@ -572,7 +573,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's # or if __tablename__ is in __annotations__. Only set __tablename__ if it's
# a table model # a table model
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"): if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
new_cls.__tablename__ = new_cls.__name__.lower() setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
# If it was passed by kwargs, ensure it's also set in config # If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table) set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items(): for k, v in get_model_fields(new_cls).items():
@ -731,7 +732,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type") raise ValueError(f"{type_} has no matching SQLAlchemy type")
def get_column_from_field(field: Any) -> Column: # type: ignore def get_column_from_field(field: Any) -> Column | MappedColumn: # type: ignore
if IS_PYDANTIC_V2: if IS_PYDANTIC_V2:
field_info = field field_info = field
else: else: