🐛 Fix Enum handling in SQLAlchemy (#165)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent
2fab4817fe
commit
eef0b7770b
@ -31,18 +31,9 @@ from pydantic.fields import ModelField, Undefined, UndefinedType
|
|||||||
from pydantic.main import ModelMetaclass, validate_model
|
from pydantic.main import ModelMetaclass, validate_model
|
||||||
from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations
|
from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations
|
||||||
from pydantic.utils import ROOT_KEY, Representation
|
from pydantic.utils import ROOT_KEY, Representation
|
||||||
from sqlalchemy import (
|
from sqlalchemy import Boolean, Column, Date, DateTime
|
||||||
Boolean,
|
from sqlalchemy import Enum as sa_Enum
|
||||||
Column,
|
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
|
||||||
Date,
|
|
||||||
DateTime,
|
|
||||||
Float,
|
|
||||||
ForeignKey,
|
|
||||||
Integer,
|
|
||||||
Interval,
|
|
||||||
Numeric,
|
|
||||||
inspect,
|
|
||||||
)
|
|
||||||
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
|
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
|
||||||
from sqlalchemy.orm.attributes import set_attribute
|
from sqlalchemy.orm.attributes import set_attribute
|
||||||
from sqlalchemy.orm.decl_api import DeclarativeMeta
|
from sqlalchemy.orm.decl_api import DeclarativeMeta
|
||||||
@ -396,7 +387,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
|
|||||||
if issubclass(field.type_, time):
|
if issubclass(field.type_, time):
|
||||||
return Time
|
return Time
|
||||||
if issubclass(field.type_, Enum):
|
if issubclass(field.type_, Enum):
|
||||||
return Enum
|
return sa_Enum(field.type_)
|
||||||
if issubclass(field.type_, bytes):
|
if issubclass(field.type_, bytes):
|
||||||
return LargeBinary
|
return LargeBinary
|
||||||
if issubclass(field.type_, Decimal):
|
if issubclass(field.type_, Decimal):
|
||||||
|
72
tests/test_enums.py
Normal file
72
tests/test_enums.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import create_mock_engine
|
||||||
|
from sqlalchemy.sql.type_api import TypeEngine
|
||||||
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tests related to Enums
|
||||||
|
|
||||||
|
Associated issues:
|
||||||
|
* https://github.com/tiangolo/sqlmodel/issues/96
|
||||||
|
* https://github.com/tiangolo/sqlmodel/issues/164
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MyEnum1(enum.Enum):
|
||||||
|
A = "A"
|
||||||
|
B = "B"
|
||||||
|
|
||||||
|
|
||||||
|
class MyEnum2(enum.Enum):
|
||||||
|
C = "C"
|
||||||
|
D = "D"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(SQLModel):
|
||||||
|
id: uuid.UUID = Field(primary_key=True)
|
||||||
|
enum_field: MyEnum2
|
||||||
|
|
||||||
|
|
||||||
|
class FlatModel(SQLModel, table=True):
|
||||||
|
id: uuid.UUID = Field(primary_key=True)
|
||||||
|
enum_field: MyEnum1
|
||||||
|
|
||||||
|
|
||||||
|
class InheritModel(BaseModel, table=True):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def pg_dump(sql: TypeEngine, *args, **kwargs):
|
||||||
|
dialect = sql.compile(dialect=postgres_engine.dialect)
|
||||||
|
sql_str = str(dialect).rstrip()
|
||||||
|
if sql_str:
|
||||||
|
print(sql_str + ";")
|
||||||
|
|
||||||
|
|
||||||
|
def sqlite_dump(sql: TypeEngine, *args, **kwargs):
|
||||||
|
dialect = sql.compile(dialect=sqlite_engine.dialect)
|
||||||
|
sql_str = str(dialect).rstrip()
|
||||||
|
if sql_str:
|
||||||
|
print(sql_str + ";")
|
||||||
|
|
||||||
|
|
||||||
|
postgres_engine = create_mock_engine("postgresql://", pg_dump)
|
||||||
|
sqlite_engine = create_mock_engine("sqlite://", sqlite_dump)
|
||||||
|
|
||||||
|
|
||||||
|
def test_postgres_ddl_sql(capsys):
|
||||||
|
SQLModel.metadata.create_all(bind=postgres_engine, checkfirst=False)
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "CREATE TYPE myenum1 AS ENUM ('A', 'B');" in captured.out
|
||||||
|
assert "CREATE TYPE myenum2 AS ENUM ('C', 'D');" in captured.out
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlite_ddl_sql(capsys):
|
||||||
|
SQLModel.metadata.create_all(bind=sqlite_engine, checkfirst=False)
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "enum_field VARCHAR(1) NOT NULL" in captured.out
|
||||||
|
assert "CREATE TYPE" not in captured.out
|
Loading…
x
Reference in New Issue
Block a user