Raise a more clear error when a type is not valid (#425)

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
David Danier 2023-10-23 08:42:30 +02:00 committed by GitHub
parent a8a792e3c0
commit 840fd08ab2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 39 deletions

View File

@ -374,45 +374,46 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
def get_sqlalchemy_type(field: ModelField) -> Any:
if issubclass(field.type_, str):
if field.field_info.max_length:
return AutoString(length=field.field_info.max_length)
return AutoString
if issubclass(field.type_, float):
return Float
if issubclass(field.type_, bool):
return Boolean
if issubclass(field.type_, int):
return Integer
if issubclass(field.type_, datetime):
return DateTime
if issubclass(field.type_, date):
return Date
if issubclass(field.type_, timedelta):
return Interval
if issubclass(field.type_, time):
return Time
if issubclass(field.type_, Enum):
return sa_Enum(field.type_)
if issubclass(field.type_, bytes):
return LargeBinary
if issubclass(field.type_, Decimal):
return Numeric(
precision=getattr(field.type_, "max_digits", None),
scale=getattr(field.type_, "decimal_places", None),
)
if issubclass(field.type_, ipaddress.IPv4Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv4Network):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Network):
return AutoString
if issubclass(field.type_, Path):
return AutoString
if issubclass(field.type_, uuid.UUID):
return GUID
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
if issubclass(field.type_, str):
if field.field_info.max_length:
return AutoString(length=field.field_info.max_length)
return AutoString
if issubclass(field.type_, float):
return Float
if issubclass(field.type_, bool):
return Boolean
if issubclass(field.type_, int):
return Integer
if issubclass(field.type_, datetime):
return DateTime
if issubclass(field.type_, date):
return Date
if issubclass(field.type_, timedelta):
return Interval
if issubclass(field.type_, time):
return Time
if issubclass(field.type_, Enum):
return sa_Enum(field.type_)
if issubclass(field.type_, bytes):
return LargeBinary
if issubclass(field.type_, Decimal):
return Numeric(
precision=getattr(field.type_, "max_digits", None),
scale=getattr(field.type_, "decimal_places", None),
)
if issubclass(field.type_, ipaddress.IPv4Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv4Network):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Network):
return AutoString
if issubclass(field.type_, Path):
return AutoString
if issubclass(field.type_, uuid.UUID):
return GUID
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")

View File

@ -0,0 +1,28 @@
from typing import Any, Dict, List, Optional, Union
import pytest
from sqlmodel import Field, SQLModel
def test_type_list_breaks() -> None:
with pytest.raises(ValueError):
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
tags: List[str]
def test_type_dict_breaks() -> None:
with pytest.raises(ValueError):
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
tags: Dict[str, Any]
def test_type_union_breaks() -> None:
with pytest.raises(ValueError):
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
tags: Union[int, str]