From 1b7b3aa668cf3cd70bbaa2024a7a2eb015af9bc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 17 Feb 2024 14:34:57 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20class=20initialization=20c?= =?UTF-8?q?ompatibility=20with=20Pydantic=20and=20SQLModel,=20fixing=20err?= =?UTF-8?q?ors=20revealed=20by=20the=20latest=20Pydantic=20(#807)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/test.yml | 4 ++-- sqlmodel/_compat.py | 14 ++++++-------- sqlmodel/main.py | 6 +++--- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ade60f2..89da640 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -62,10 +62,10 @@ jobs: run: python -m poetry install - name: Install Pydantic v1 if: matrix.pydantic-version == 'pydantic-v1' - run: pip install "pydantic>=1.10.0,<2.0.0" + run: pip install --upgrade "pydantic>=1.10.0,<2.0.0" - name: Install Pydantic v2 if: matrix.pydantic-version == 'pydantic-v2' - run: pip install "pydantic>=2.0.2,<3.0.0" + run: pip install --upgrade "pydantic>=2.0.2,<3.0.0" - name: Lint # Do not run on Python 3.7 as mypy behaves differently if: matrix.python-version != '3.7' && matrix.pydantic-version == 'pydantic-v2' diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 2a2caca..76771ce 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -97,10 +97,10 @@ if IS_PYDANTIC_V2: def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: return model.model_fields - def set_fields_set( - new_object: InstanceOrType["SQLModel"], fields: Set["FieldInfo"] - ) -> None: - object.__setattr__(new_object, "__pydantic_fields_set__", fields) + def init_pydantic_private_attrs(new_object: InstanceOrType["SQLModel"]) -> None: + object.__setattr__(new_object, "__pydantic_fields_set__", set()) + object.__setattr__(new_object, "__pydantic_extra__", None) + object.__setattr__(new_object, "__pydantic_private__", None) def get_annotations(class_dict: Dict[str, Any]) -> Dict[str, Any]: return class_dict.get("__annotations__", {}) @@ -387,10 +387,8 @@ else: def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: return model.__fields__ # type: ignore - def set_fields_set( - new_object: InstanceOrType["SQLModel"], fields: Set["FieldInfo"] - ) -> None: - object.__setattr__(new_object, "__fields_set__", fields) + def init_pydantic_private_attrs(new_object: InstanceOrType["SQLModel"]) -> None: + object.__setattr__(new_object, "__fields_set__", set()) def get_annotations(class_dict: Dict[str, Any]) -> Dict[str, Any]: return resolve_annotations( # type: ignore[no-any-return] diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 10064c7..fec3bc7 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -70,11 +70,11 @@ from ._compat import ( # type: ignore[attr-defined] get_model_fields, get_relationship_to, get_type_from_field, + init_pydantic_private_attrs, is_field_noneable, is_table_model_class, post_init_field_info, set_config_value, - set_fields_set, sqlmodel_init, sqlmodel_validate, ) @@ -686,12 +686,12 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry def __new__(cls, *args: Any, **kwargs: Any) -> Any: new_object = super().__new__(cls) - # SQLAlchemy doesn't call __init__ on the base class + # SQLAlchemy doesn't call __init__ on the base class when querying from DB # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html # Set __fields_set__ here, that would have been set when calling __init__ # in the Pydantic model so that when SQLAlchemy sets attributes that are # added (e.g. when querying from DB) to the __fields_set__, this already exists - set_fields_set(new_object, set()) + init_pydantic_private_attrs(new_object) return new_object def __init__(__pydantic_self__, **data: Any) -> None: