From 76d72cd32eb59de127d7c0b87cce3ab456f5cd02 Mon Sep 17 00:00:00 2001 From: Esteban Maya Cadavid Date: Mon, 1 Jul 2024 17:28:11 -0500 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20support=20for=20safe=20access?= =?UTF-8?q?=20to=20PK=20int=20autotypes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 58 +++++++++++++++++++++++++++++++++++++++++++++ sqlmodel/main.py | 7 ++++++ 2 files changed, 65 insertions(+) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 4018d1b..02d5613 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -219,6 +219,33 @@ if IS_PYDANTIC_V2: ) -> Optional[AbstractSet[str]]: # pragma: no cover return None + def validate_access_primary_key_autotype( + self: InstanceOrType["SQLModel"], name: str, value: Any + ) -> None: + """ + Pydantic v2 + Validates if the attribute being accessed is a primary key with an auto type and has not been set. + + Args: + self (InstanceOrType["SQLModel"]): The instance or type of SQLModel. + name (str): The name of the attribute being accessed. + value (Any): The value of the attribute being accessed. + + Raises: + ValueError: If the attribute is a primary key with an auto type and has not been set. + + Returns: + None + """ + if name != "model_fields": + model_fields = object.__getattribute__(self, "model_fields") + field = model_fields.get(name) + if field is not None and isinstance(field, FieldInfo): + if field.primary_key and field.annotation is int and value is None: + raise ValueError( + f"Primary key attribute '{name}' has not been set, please commit() it first." + ) + def sqlmodel_table_construct( *, self_instance: _TSQLModel, @@ -499,6 +526,37 @@ else: return keys + def validate_access_primary_key_autotype( + self: InstanceOrType["SQLModel"], name: str, value: Any + ) -> None: + """ + Pydantic v1 + Validates if the attribute being accessed is a primary key with an auto type and has not been set. + + Args: + self (InstanceOrType["SQLModel"]): The instance or type of SQLModel. + name (str): The name of the attribute being accessed. + value (Any): The value of the attribute being accessed. + + Raises: + ValueError: If the attribute is a primary key with an auto type and has not been set. + + Returns: + None + """ + if name != "__fields__": + fields = object.__getattribute__(self, "__fields__") + field = fields.get(name) + if field is not None and isinstance(field.field_info, FieldInfo): + if ( + field.field_info.primary_key + and field.annotation is int + and value is None + ): + raise ValueError( + f"Primary key attribute '{name}' has not been set, please commit() it first." + ) + def sqlmodel_validate( cls: Type[_TSQLModel], obj: Any, diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 505683f..6a5a46a 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -79,6 +79,7 @@ from ._compat import ( # type: ignore[attr-defined] set_config_value, sqlmodel_init, sqlmodel_validate, + validate_access_primary_key_autotype, ) from .sql.sqltypes import GUID, AutoString @@ -732,6 +733,12 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry if name not in self.__sqlmodel_relationships__: super().__setattr__(name, value) + def __getattribute__(self, name: str) -> Any: + # Access attributes safely using object.__getattribute__ to avoid recursion + value = object.__getattribute__(self, name) + validate_access_primary_key_autotype(self, name, value) + return value + def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # Don't show SQLAlchemy private attributes return [