From cbaf172c63889985e94031fb38ff00e80b7c90bf Mon Sep 17 00:00:00 2001 From: "Maruo.S" Date: Sun, 29 Oct 2023 17:10:39 +0900 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20support=20for=20passing=20a?= =?UTF-8?q?=20custom=20SQLAlchemy=20type=20to=20`Field()`=20with=20`sa=5Ft?= =?UTF-8?q?ype`=20(#505)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sebastián Ramírez --- sqlmodel/main.py | 16 ++++++++++++++-- tests/test_field_sa_column.py | 11 +++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index f48e388..2b69dd2 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -74,6 +74,7 @@ class FieldInfo(PydanticFieldInfo): foreign_key = kwargs.pop("foreign_key", Undefined) unique = kwargs.pop("unique", False) index = kwargs.pop("index", Undefined) + sa_type = kwargs.pop("sa_type", Undefined) sa_column = kwargs.pop("sa_column", Undefined) sa_column_args = kwargs.pop("sa_column_args", Undefined) sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) @@ -104,11 +105,15 @@ class FieldInfo(PydanticFieldInfo): ) if unique is not Undefined: raise RuntimeError( - "Passing unique is not supported when " "also passing a sa_column" + "Passing unique is not supported when also passing a sa_column" ) if index is not Undefined: raise RuntimeError( - "Passing index is not supported when " "also passing a sa_column" + "Passing index is not supported when also passing a sa_column" + ) + if sa_type is not Undefined: + raise RuntimeError( + "Passing sa_type is not supported when also passing a sa_column" ) super().__init__(default=default, **kwargs) self.primary_key = primary_key @@ -116,6 +121,7 @@ class FieldInfo(PydanticFieldInfo): self.foreign_key = foreign_key self.unique = unique self.index = index + self.sa_type = sa_type self.sa_column = sa_column self.sa_column_args = sa_column_args self.sa_column_kwargs = sa_column_kwargs @@ -185,6 +191,7 @@ def Field( unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, + sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, @@ -264,6 +271,7 @@ def Field( unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, + sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, @@ -300,6 +308,7 @@ def Field( unique=unique, nullable=nullable, index=index, + sa_type=sa_type, sa_column=sa_column, sa_column_args=sa_column_args, sa_column_kwargs=sa_column_kwargs, @@ -515,6 +524,9 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): def get_sqlalchemy_type(field: ModelField) -> Any: + sa_type = getattr(field.field_info, "sa_type", Undefined) # noqa: B009 + if sa_type is not Undefined: + return sa_type if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(field.type_, Enum): diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py index 51cfdfa..7384f1f 100644 --- a/tests/test_field_sa_column.py +++ b/tests/test_field_sa_column.py @@ -39,6 +39,17 @@ def test_sa_column_no_sa_kargs() -> None: ) +def test_sa_column_no_type() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_type=Integer, + sa_column=Column(Integer, primary_key=True), + ) + + def test_sa_column_no_primary_key() -> None: with pytest.raises(RuntimeError):