From 9509313eaf5841790a0503fbba5527c950ad54ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 29 Oct 2023 12:00:37 +0400 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20checking=20for?= =?UTF-8?q?=20sa=5Ftype?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 8801730..266bb6c 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -105,11 +105,15 @@ class FieldInfo(PydanticFieldInfo): ) if unique is not Undefined: raise RuntimeError( - "Passing unique is not supported when " "also passing a sa_column" + "Passing unique is not supported when also passing a sa_column" ) if index is not Undefined: raise RuntimeError( - "Passing index is not supported when " "also passing a sa_column" + "Passing index is not supported when also passing a sa_column" + ) + if sa_type is not Undefined: + raise RuntimeError( + "Passing sa_type is not supported when also passing a sa_column" ) super().__init__(default=default, **kwargs) self.primary_key = primary_key @@ -187,6 +191,7 @@ def Field( unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, + sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, @@ -266,7 +271,7 @@ def Field( unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, - sa_type: Type[Any] = Undefined, + sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, @@ -519,9 +524,9 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): def get_sqlalchemy_type(field: ModelField) -> Any: - if hasattr(field.field_info, "sa_type"): - if not issubclass(type(field.field_info.sa_type), type(Undefined)): - return field.field_info.sa_type + sa_type = getattr(field.field_info, "sa_type") # noqa: B009 + if sa_type is not Undefined: + return sa_type if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(field.type_, Enum):