🐛 Fix setting nullable property of Fields that don't accept None (#79)

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Evangelos Anagnostopoulos 2022-08-28 01:18:57 +03:00 committed by GitHub
parent 2407ecd2bf
commit 9830ee0d89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 2 deletions

View File

@ -25,6 +25,7 @@ from typing import (
from pydantic import BaseConfig, BaseModel from pydantic import BaseConfig, BaseModel
from pydantic.errors import ConfigError, DictError from pydantic.errors import ConfigError, DictError
from pydantic.fields import SHAPE_SINGLETON
from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass, validate_model from pydantic.main import ModelMetaclass, validate_model
@ -424,7 +425,6 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
return sa_column return sa_column
sa_type = get_sqlachemy_type(field) sa_type = get_sqlachemy_type(field)
primary_key = getattr(field.field_info, "primary_key", False) primary_key = getattr(field.field_info, "primary_key", False)
nullable = not field.required
index = getattr(field.field_info, "index", Undefined) index = getattr(field.field_info, "index", Undefined)
if index is Undefined: if index is Undefined:
index = False index = False
@ -432,6 +432,7 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
field_nullable = getattr(field.field_info, "nullable") field_nullable = getattr(field.field_info, "nullable")
if field_nullable != Undefined: if field_nullable != Undefined:
nullable = field_nullable nullable = field_nullable
nullable = not primary_key and _is_field_nullable(field)
args = [] args = []
foreign_key = getattr(field.field_info, "foreign_key", None) foreign_key = getattr(field.field_info, "foreign_key", None)
if foreign_key: if foreign_key:
@ -646,3 +647,13 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
@declared_attr # type: ignore @declared_attr # type: ignore
def __tablename__(cls) -> str: def __tablename__(cls) -> str:
return cls.__name__.lower() return cls.__name__.lower()
def _is_field_nullable(field: ModelField) -> bool:
if not field.required:
# Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
is_optional = field.allow_none and (
field.shape != SHAPE_SINGLETON or not field.sub_fields
)
return is_optional and field.default is None and field.default_factory is None
return False

View File

@ -9,7 +9,7 @@ def test_create_db_and_table(cov_tmp_path: Path):
assert "BEGIN" in result.stdout assert "BEGIN" in result.stdout
assert 'PRAGMA main.table_info("hero")' in result.stdout assert 'PRAGMA main.table_info("hero")' in result.stdout
assert "CREATE TABLE hero (" in result.stdout assert "CREATE TABLE hero (" in result.stdout
assert "id INTEGER," in result.stdout assert "id INTEGER NOT NULL," in result.stdout
assert "name VARCHAR NOT NULL," in result.stdout assert "name VARCHAR NOT NULL," in result.stdout
assert "secret_name VARCHAR NOT NULL," in result.stdout assert "secret_name VARCHAR NOT NULL," in result.stdout
assert "age INTEGER," in result.stdout assert "age INTEGER," in result.stdout