✨ Raise a more clear error when a type is not valid (#425)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent
a8a792e3c0
commit
840fd08ab2
@ -374,45 +374,46 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
|
|
||||||
|
|
||||||
def get_sqlalchemy_type(field: ModelField) -> Any:
|
def get_sqlalchemy_type(field: ModelField) -> Any:
|
||||||
if issubclass(field.type_, str):
|
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
|
||||||
if field.field_info.max_length:
|
if issubclass(field.type_, str):
|
||||||
return AutoString(length=field.field_info.max_length)
|
if field.field_info.max_length:
|
||||||
return AutoString
|
return AutoString(length=field.field_info.max_length)
|
||||||
if issubclass(field.type_, float):
|
return AutoString
|
||||||
return Float
|
if issubclass(field.type_, float):
|
||||||
if issubclass(field.type_, bool):
|
return Float
|
||||||
return Boolean
|
if issubclass(field.type_, bool):
|
||||||
if issubclass(field.type_, int):
|
return Boolean
|
||||||
return Integer
|
if issubclass(field.type_, int):
|
||||||
if issubclass(field.type_, datetime):
|
return Integer
|
||||||
return DateTime
|
if issubclass(field.type_, datetime):
|
||||||
if issubclass(field.type_, date):
|
return DateTime
|
||||||
return Date
|
if issubclass(field.type_, date):
|
||||||
if issubclass(field.type_, timedelta):
|
return Date
|
||||||
return Interval
|
if issubclass(field.type_, timedelta):
|
||||||
if issubclass(field.type_, time):
|
return Interval
|
||||||
return Time
|
if issubclass(field.type_, time):
|
||||||
if issubclass(field.type_, Enum):
|
return Time
|
||||||
return sa_Enum(field.type_)
|
if issubclass(field.type_, Enum):
|
||||||
if issubclass(field.type_, bytes):
|
return sa_Enum(field.type_)
|
||||||
return LargeBinary
|
if issubclass(field.type_, bytes):
|
||||||
if issubclass(field.type_, Decimal):
|
return LargeBinary
|
||||||
return Numeric(
|
if issubclass(field.type_, Decimal):
|
||||||
precision=getattr(field.type_, "max_digits", None),
|
return Numeric(
|
||||||
scale=getattr(field.type_, "decimal_places", None),
|
precision=getattr(field.type_, "max_digits", None),
|
||||||
)
|
scale=getattr(field.type_, "decimal_places", None),
|
||||||
if issubclass(field.type_, ipaddress.IPv4Address):
|
)
|
||||||
return AutoString
|
if issubclass(field.type_, ipaddress.IPv4Address):
|
||||||
if issubclass(field.type_, ipaddress.IPv4Network):
|
return AutoString
|
||||||
return AutoString
|
if issubclass(field.type_, ipaddress.IPv4Network):
|
||||||
if issubclass(field.type_, ipaddress.IPv6Address):
|
return AutoString
|
||||||
return AutoString
|
if issubclass(field.type_, ipaddress.IPv6Address):
|
||||||
if issubclass(field.type_, ipaddress.IPv6Network):
|
return AutoString
|
||||||
return AutoString
|
if issubclass(field.type_, ipaddress.IPv6Network):
|
||||||
if issubclass(field.type_, Path):
|
return AutoString
|
||||||
return AutoString
|
if issubclass(field.type_, Path):
|
||||||
if issubclass(field.type_, uuid.UUID):
|
return AutoString
|
||||||
return GUID
|
if issubclass(field.type_, uuid.UUID):
|
||||||
|
return GUID
|
||||||
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
|
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
|
||||||
|
|
||||||
|
|
||||||
|
28
tests/test_sqlalchemy_type_errors.py
Normal file
28
tests/test_sqlalchemy_type_errors.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_type_list_breaks() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
tags: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
def test_type_dict_breaks() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
tags: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def test_type_union_breaks() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
tags: Union[int, str]
|
Loading…
x
Reference in New Issue
Block a user