🐛 Fix enum type checks ordering in get_sqlalchemy_type
(#669)
Co-authored-by: Pierre Cheynier <p.cheynier@criteo.com>
This commit is contained in:
parent
40c1af9202
commit
d3261cab59
@ -384,6 +384,9 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
|||||||
|
|
||||||
def get_sqlalchemy_type(field: ModelField) -> Any:
|
def get_sqlalchemy_type(field: ModelField) -> Any:
|
||||||
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
|
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
|
||||||
|
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
|
||||||
|
if issubclass(field.type_, Enum):
|
||||||
|
return sa_Enum(field.type_)
|
||||||
if issubclass(field.type_, str):
|
if issubclass(field.type_, str):
|
||||||
if field.field_info.max_length:
|
if field.field_info.max_length:
|
||||||
return AutoString(length=field.field_info.max_length)
|
return AutoString(length=field.field_info.max_length)
|
||||||
@ -402,8 +405,6 @@ def get_sqlalchemy_type(field: ModelField) -> Any:
|
|||||||
return Interval
|
return Interval
|
||||||
if issubclass(field.type_, time):
|
if issubclass(field.type_, time):
|
||||||
return Time
|
return Time
|
||||||
if issubclass(field.type_, Enum):
|
|
||||||
return sa_Enum(field.type_)
|
|
||||||
if issubclass(field.type_, bytes):
|
if issubclass(field.type_, bytes):
|
||||||
return LargeBinary
|
return LargeBinary
|
||||||
if issubclass(field.type_, Decimal):
|
if issubclass(field.type_, Decimal):
|
||||||
|
@ -14,12 +14,12 @@ Associated issues:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class MyEnum1(enum.Enum):
|
class MyEnum1(str, enum.Enum):
|
||||||
A = "A"
|
A = "A"
|
||||||
B = "B"
|
B = "B"
|
||||||
|
|
||||||
|
|
||||||
class MyEnum2(enum.Enum):
|
class MyEnum2(str, enum.Enum):
|
||||||
C = "C"
|
C = "C"
|
||||||
D = "D"
|
D = "D"
|
||||||
|
|
||||||
@ -70,3 +70,43 @@ def test_sqlite_ddl_sql(capsys):
|
|||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert "enum_field VARCHAR(1) NOT NULL" in captured.out
|
assert "enum_field VARCHAR(1) NOT NULL" in captured.out
|
||||||
assert "CREATE TYPE" not in captured.out
|
assert "CREATE TYPE" not in captured.out
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_schema_flat_model():
|
||||||
|
assert FlatModel.schema() == {
|
||||||
|
"title": "FlatModel",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"id": {"title": "Id", "type": "string", "format": "uuid"},
|
||||||
|
"enum_field": {"$ref": "#/definitions/MyEnum1"},
|
||||||
|
},
|
||||||
|
"required": ["id", "enum_field"],
|
||||||
|
"definitions": {
|
||||||
|
"MyEnum1": {
|
||||||
|
"title": "MyEnum1",
|
||||||
|
"description": "An enumeration.",
|
||||||
|
"enum": ["A", "B"],
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_schema_inherit_model():
|
||||||
|
assert InheritModel.schema() == {
|
||||||
|
"title": "InheritModel",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"id": {"title": "Id", "type": "string", "format": "uuid"},
|
||||||
|
"enum_field": {"$ref": "#/definitions/MyEnum2"},
|
||||||
|
},
|
||||||
|
"required": ["id", "enum_field"],
|
||||||
|
"definitions": {
|
||||||
|
"MyEnum2": {
|
||||||
|
"title": "MyEnum2",
|
||||||
|
"description": "An enumeration.",
|
||||||
|
"enum": ["C", "D"],
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user