fix lint
This commit is contained in:
parent
48f2a88752
commit
e6ad74d50a
@ -22,6 +22,7 @@ from pydantic import VERSION as P_VERSION
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
|
from sqlalchemy.orm import Mapper
|
||||||
from typing_extensions import Annotated, get_args, get_origin
|
from typing_extensions import Annotated, get_args, get_origin
|
||||||
|
|
||||||
# Reassign variable to make it reexported for mypy
|
# Reassign variable to make it reexported for mypy
|
||||||
@ -293,20 +294,21 @@ if IS_PYDANTIC_V2:
|
|||||||
# End SQLModel override
|
# End SQLModel override
|
||||||
# Override polymorphic_on default value
|
# Override polymorphic_on default value
|
||||||
mapper = inspect(cls)
|
mapper = inspect(cls)
|
||||||
polymorphic_on = mapper.polymorphic_on
|
if isinstance(mapper, Mapper):
|
||||||
if polymorphic_on is not None:
|
polymorphic_on = mapper.polymorphic_on
|
||||||
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
|
if polymorphic_on is not None:
|
||||||
field_info = cls.model_fields.get(polymorphic_property.key)
|
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
|
||||||
if field_info:
|
field_info = cls.model_fields.get(polymorphic_property.key)
|
||||||
v = values.get(polymorphic_property.key)
|
if field_info:
|
||||||
# if model is inherited or polymorphic_on is not explicitly set
|
v = values.get(polymorphic_property.key)
|
||||||
# set the polymorphic_on by default
|
# if model is inherited or polymorphic_on is not explicitly set
|
||||||
if mapper.inherits or v is None:
|
# set the polymorphic_on by default
|
||||||
setattr(
|
if mapper.inherits or v is None:
|
||||||
self_instance,
|
setattr(
|
||||||
polymorphic_property.key,
|
self_instance,
|
||||||
mapper.polymorphic_identity,
|
polymorphic_property.key,
|
||||||
)
|
mapper.polymorphic_identity,
|
||||||
|
)
|
||||||
return self_instance
|
return self_instance
|
||||||
|
|
||||||
def sqlmodel_validate(
|
def sqlmodel_validate(
|
||||||
|
@ -551,9 +551,10 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
for base in bases[::-1]:
|
for base in bases[::-1]:
|
||||||
if issubclass(base, BaseModel):
|
if issubclass(base, BaseModel):
|
||||||
base_fields.update(base.model_fields)
|
base_fields.update(base.model_fields)
|
||||||
for k, v in new_cls.model_fields.items():
|
fields = get_model_fields(new_cls)
|
||||||
|
for k, v in fields.items():
|
||||||
if isinstance(v.default, InstrumentedAttribute):
|
if isinstance(v.default, InstrumentedAttribute):
|
||||||
new_cls.model_fields[k] = base_fields.get(k)
|
fields[k] = base_fields.get(k, FieldInfo())
|
||||||
|
|
||||||
def get_config(name: str) -> Any:
|
def get_config(name: str) -> Any:
|
||||||
config_class_value = get_config_value(
|
config_class_value = get_config_value(
|
||||||
@ -572,7 +573,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's
|
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's
|
||||||
# a table model
|
# a table model
|
||||||
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
|
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
|
||||||
new_cls.__tablename__ = new_cls.__name__.lower()
|
setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
|
||||||
# If it was passed by kwargs, ensure it's also set in config
|
# If it was passed by kwargs, ensure it's also set in config
|
||||||
set_config_value(model=new_cls, parameter="table", value=config_table)
|
set_config_value(model=new_cls, parameter="table", value=config_table)
|
||||||
for k, v in get_model_fields(new_cls).items():
|
for k, v in get_model_fields(new_cls).items():
|
||||||
@ -731,7 +732,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
|
|||||||
raise ValueError(f"{type_} has no matching SQLAlchemy type")
|
raise ValueError(f"{type_} has no matching SQLAlchemy type")
|
||||||
|
|
||||||
|
|
||||||
def get_column_from_field(field: Any) -> Column: # type: ignore
|
def get_column_from_field(field: Any) -> Column | MappedColumn: # type: ignore
|
||||||
if IS_PYDANTIC_V2:
|
if IS_PYDANTIC_V2:
|
||||||
field_info = field
|
field_info = field
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user