Update type annotations and upgrade mypy (#173)

Sebastián Ramírez 2021-11-30 17:12:28 +01:00 committed by GitHub
parent 02da85c9ec
commit e30c7ef4e9
10 changed files with 90 additions and 76 deletions

View File

@@ -37,7 +37,7 @@ sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
[tool.poetry.dev-dependencies]
pytest = "^6.2.4"
mypy = "^0.812"
mypy = "^0.910"
flake8 = "^3.9.2"
black = {version = "^21.5-beta.1", python = "^3.7"}
mkdocs = "^1.2.1"
@@ -98,3 +98,7 @@ warn_return_any = true
implicit_reexport = false
strict_equality = true
# --strict end
[[tool.mypy.overrides]]
module = "sqlmodel.sql.expression"
warn_unused_ignores = false
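
The new override turns warn_unused_ignores off only for sqlmodel.sql.expression, presumably because that module carries `# type: ignore` comments whose necessity depends on which Python-version branch mypy analyzes, so the same ignore can be required in one configuration and reported as unused in another. A hypothetical sketch of that situation (not code from this commit):

```python
# Hypothetical illustration: with warn_unused_ignores enabled, an ignore that
# is only needed in one version branch is reported as unused when mypy treats
# the other branch as unreachable.
import sys

if sys.version_info >= (3, 7):
    from typing import Generic as _GenericBase
else:  # analyzed only when mypy targets Python 3.6
    from typing import GenericMeta as _GenericBase  # type: ignore
```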

View File

@@ -136,4 +136,4 @@ def create_engine(
if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs)
return _create_engine(url, **current_kwargs)
return _create_engine(url, **current_kwargs) # type: ignore
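
The only change in this file is the ignore added to the final call, presumably because under the project's strict settings (warn_return_any and friends) mypy 0.910 with sqlalchemy2-stubs objects to returning a loosely typed value from a wrapper with a concrete return annotation. A minimal sketch of the wrapper pattern (simplified parameters, not the full sqlmodel implementation):

```python
from typing import Any, Dict

from sqlalchemy import create_engine as _create_engine
from sqlalchemy.future import Engine as _FutureEngine


def create_engine(url: str, echo: bool = False, **kwargs: Any) -> _FutureEngine:
    # Collect keyword arguments into a plain dict and forward them; the stubs
    # type parts of this call loosely enough that strict mypy wants an ignore.
    current_kwargs: Dict[str, Any] = {"echo": echo, "future": True}
    current_kwargs.update(kwargs)
    return _create_engine(url, **current_kwargs)  # type: ignore
```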

View File

@@ -23,7 +23,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
return super().__iter__()
def __next__(self) -> _T:
return super().__next__()
return super().__next__() # type: ignore
def first(self) -> Optional[_T]:
return super().first()
@@ -32,7 +32,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
return super().one_or_none()
def one(self) -> _T:
return super().one()
return super().one() # type: ignore
class Result(_Result, Generic[_T]):
@@ -70,10 +70,10 @@ class Result(_Result, Generic[_T]):
return super().scalar_one() # type: ignore
def scalar_one_or_none(self) -> Optional[_T]:
return super().scalar_one_or_none() # type: ignore
return super().scalar_one_or_none()
def one(self) -> _T: # type: ignore
return super().one() # type: ignore
def scalar(self) -> Optional[_T]:
return super().scalar() # type: ignore
return super().scalar()
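
The ignores in this file move rather than disappear: with mypy 0.910 and the current sqlalchemy2-stubs, some delegating methods now check cleanly (so a leftover ignore would trip warn_unused_ignores), while others still need one to narrow the stubbed return type to `_T`. A condensed sketch of the wrapper pattern:

```python
from typing import Generic, Optional, TypeVar

from sqlalchemy.engine.result import ScalarResult as _ScalarResult

_T = TypeVar("_T")


class ScalarResult(_ScalarResult, Generic[_T]):
    def first(self) -> Optional[_T]:
        # Checks cleanly against the current stubs, so no ignore is needed.
        return super().first()

    def one(self) -> _T:
        # The stubs are looser here; narrowing to _T still needs an ignore.
        return super().one()  # type: ignore
```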

View File

@@ -21,7 +21,7 @@ class AsyncSession(_AsyncSession):
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
**kw,
**kw: Any,
):
# All the same code as the original AsyncSession
kw["future"] = True
@@ -52,7 +52,7 @@ class AsyncSession(_AsyncSession):
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
return await greenlet_spawn( # type: ignore
return await greenlet_spawn(
self.sync_session.exec,
statement,
params=params,
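
Two small strictness fixes here: `**kw` gains an annotation, since disallow_untyped_defs requires every parameter, variadics included, to be annotated, and the ignore on greenlet_spawn is dropped because it is no longer needed (and would itself now be flagged by warn_unused_ignores). For reference, `**kw: Any` types each keyword value as Any:

```python
from typing import Any


def configure_session(**kw: Any) -> None:
    # kw behaves as Dict[str, Any]; each value is typed as Any.
    kw["future"] = True
```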

View File

@@ -101,7 +101,7 @@ class RelationshipInfo(Representation):
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> None:
@@ -127,32 +127,32 @@ def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: str = None,
title: str = None,
description: str = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: bool = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
multiple_of: float = None,
min_items: int = None,
max_items: int = None,
min_length: int = None,
max_length: int = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: str = None,
regex: Optional[str] = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
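
The bulk of the Field() changes spell out Optional[...] on parameters that default to None. Under no_implicit_optional, which is part of the --strict settings this project enables, `alias: str = None` is rejected and the Optional has to be explicit. A minimal illustration (hypothetical function, not the real signature):

```python
from typing import Optional


def field_sketch(alias: str = None) -> None:
    # mypy (no_implicit_optional): Incompatible default for argument "alias"
    # (default has type "None", argument has type "str")
    ...


def field_fixed(alias: Optional[str] = None) -> None:
    # Accepted: the None default is reflected in the annotation.
    ...
```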
@@ -195,7 +195,7 @@ def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
@@ -217,19 +217,25 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
if getattr(cls.__config__, "table", False): # type: ignore
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__setattr__(cls, name, value)
else:
super().__setattr__(name, value)
def __delattr__(cls, name: str) -> None:
if getattr(cls.__config__, "table", False): # type: ignore
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__delattr__(cls, name)
else:
super().__delattr__(name)
# From Pydantic
def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
def __new__(
cls,
name: str,
bases: Tuple[Type[Any], ...],
class_dict: Dict[str, Any],
**kwargs: Any,
) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {}
original_annotations = resolve_annotations(
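
In this hunk the ignores on `cls.__config__` are no longer needed under mypy 0.910 (keeping them would trip warn_unused_ignores), and the metaclass `__new__` gains full annotations to satisfy disallow_untyped_defs. The bare-bones shape of such a signature, reduced to an illustrative metaclass:

```python
from typing import Any, Dict, Tuple, Type


class SketchMeta(type):
    # Fully annotated so disallow_untyped_defs is satisfied; the return stays
    # Any here, matching the style used in the diff above.
    def __new__(
        cls,
        name: str,
        bases: Tuple[Type[Any], ...],
        class_dict: Dict[str, Any],
        **kwargs: Any,
    ) -> Any:
        return super().__new__(cls, name, bases, class_dict)
```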
@@ -342,7 +348,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
)
relationship_to = temp_field.type_
if isinstance(temp_field.type_, ForwardRef):
relationship_to = temp_field.type_.__forward_arg__ # type: ignore
relationship_to = temp_field.type_.__forward_arg__
rel_kwargs: Dict[str, Any] = {}
if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates
@@ -360,7 +366,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value: RelationshipProperty = relationship(
rel_value: RelationshipProperty = relationship( # type: ignore
relationship_to, *rel_args, **rel_kwargs
)
dict_used[rel_name] = rel_value
@@ -408,7 +414,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
return GUID
def get_column_from_field(field: ModelField) -> Column:
def get_column_from_field(field: ModelField) -> Column: # type: ignore
sa_column = getattr(field.field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
return sa_column
@@ -440,10 +446,10 @@ def get_column_from_field(field: ModelField) -> Column:
kwargs["default"] = sa_default
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined:
args.extend(list(cast(Sequence, sa_column_args)))
args.extend(list(cast(Sequence[Any], sa_column_args)))
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
if sa_column_kwargs is not Undefined:
kwargs.update(cast(dict, sa_column_kwargs))
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
return Column(sa_type, *args, **kwargs)
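
The cast targets gain type parameters because disallow_any_generics, also part of --strict, flags bare generic types like Sequence or dict in type positions, which in this configuration includes the target of a cast. For example:

```python
from typing import Any, Dict, Sequence, cast

raw_args: Any = (1, 2, 3)
args = list(cast(Sequence[Any], raw_args))      # accepted
# list(cast(Sequence, raw_args))                # Missing type parameters for generic type "Sequence"

raw_kwargs: Any = {"nullable": True}
kwargs = dict(cast(Dict[Any, Any], raw_kwargs))
```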
@@ -452,24 +458,27 @@ class_registry = weakref.WeakValueDictionary() # type: ignore
default_registry = registry()
def _value_items_is_true(v) -> bool:
def _value_items_is_true(v: Any) -> bool:
# Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
# the current latest, Pydantic 1.8.2
return v is True or v is ...
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
class Config:
orm_mode = True
def __new__(cls, *args, **kwargs) -> Any:
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
new_object = super().__new__(cls)
# SQLAlchemy doesn't call __init__ on the base class
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
@@ -520,7 +529,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
super().__setattr__(name, value)
@classmethod
def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
def from_orm(
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
) -> _TSQLModel:
# Duplicated from Pydantic
if not cls.__config__.orm_mode:
raise ConfigError(
@@ -533,7 +544,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# End SQLModel support dict
if not getattr(cls.__config__, "table", False):
# If not table, normal Pydantic code
m = cls.__new__(cls)
m: _TSQLModel = cls.__new__(cls)
else:
# If table, create the new instance normally to make SQLAlchemy create
# the _sa_instance_state attribute
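
from_orm previously returned "SQLModel", which lost the subclass type at call sites; binding a TypeVar to the base class makes, say, Hero.from_orm(...) check as Hero. The general pattern, reduced to a few lines (Base, Hero, and build are illustrative names, not SQLModel API):

```python
from typing import Type, TypeVar

_TBase = TypeVar("_TBase", bound="Base")


class Base:
    @classmethod
    def build(cls: Type[_TBase]) -> _TBase:
        # cls is the concrete subclass, so the return type follows it.
        return cls()


class Hero(Base):
    pass


hero = Hero.build()  # mypy sees Hero here, not Base
```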
@@ -554,7 +565,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
@classmethod
def parse_obj(
cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
) -> "SQLModel":
obj = cls._enforce_dict_if_root(obj)
# SQLModel, support update dict

View File

@@ -60,7 +60,7 @@ class Session(_Session):
results = super().execute(
statement,
params=params,
execution_options=execution_options, # type: ignore
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
@@ -74,7 +74,7 @@ class Session(_Session):
self,
statement: _Executable,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
@@ -101,7 +101,7 @@ class Session(_Session):
return super().execute( # type: ignore
statement,
params=params,
execution_options=execution_options, # type: ignore
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,

View File

@@ -6,6 +6,4 @@ _T = TypeVar("_T")
class Executable(_Executable, Generic[_T]):
def __init__(self, *args, **kwargs):
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
super(_Executable, self).__init__(*args, **kwargs)
pass
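
The custom, unannotated `__init__` is removed and the class body reduced to `pass`; the surrounding machinery apparently no longer needs the option shuffling it did, and keeping it would have required full annotations under disallow_untyped_defs. Had it stayed, a strict-clean version would have looked roughly like this (illustrative class name):

```python
from typing import Any


class ExecutableSketch:
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Loosely mirrors the removed code: stash the private option directly
        # in __dict__ and accept arbitrary positional/keyword arguments.
        self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
```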

View File

@@ -45,10 +45,10 @@ else:
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
pass
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass
# Cast them for editors to work correctly, from several tricks tried, this works
@@ -65,9 +65,9 @@ if TYPE_CHECKING: # pragma: no cover
_TScalar_0 = TypeVar(
"_TScalar_0",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -83,9 +83,9 @@ _TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
_TScalar_1 = TypeVar(
"_TScalar_1",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -101,9 +101,9 @@ _TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
_TScalar_2 = TypeVar(
"_TScalar_2",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -119,9 +119,9 @@ _TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
_TScalar_3 = TypeVar(
"_TScalar_3",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -446,14 +446,14 @@ def select( # type: ignore
# Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw) # type: ignore
# TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause:
def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression
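
The per-constraint ignores on Column, Sequence, and Mapping are presumably needed because disallow_any_generics flags bare generic types used as TypeVar constraints, and these particular constraints are hard to parameterize meaningfully (hence also the module-level warn_unused_ignores override in pyproject.toml). The shape of the pattern:

```python
from typing import Mapping, Sequence, TypeVar

from sqlalchemy import Column

_TScalar = TypeVar(
    "_TScalar",
    Column,  # type: ignore
    Sequence,  # type: ignore
    Mapping,  # type: ignore
    float,
    int,
    str,
)
```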

View File

@@ -63,9 +63,9 @@ if TYPE_CHECKING: # pragma: no cover
{% for i in range(number_of_types) %}
_TScalar_{{ i }} = TypeVar(
"_TScalar_{{ i }}",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -106,14 +106,14 @@ def select( # type: ignore
# Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw)
return Select._create(*entities, **kw) # type: ignore
# TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause:
def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression

View File

@@ -1,13 +1,14 @@
import uuid
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import types
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.types import CHAR, TypeDecorator
class AutoString(types.TypeDecorator):
class AutoString(types.TypeDecorator): # type: ignore
impl = types.String
cache_ok = True
@@ -22,7 +23,7 @@ class AutoString(types.TypeDecorator):
# Reference from SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
# with small modifications
class GUID(TypeDecorator):
class GUID(TypeDecorator): # type: ignore
"""Platform-independent GUID type.
Uses PostgreSQL's UUID type, otherwise uses
@@ -33,13 +34,13 @@ class GUID(TypeDecorator):
impl = CHAR
cache_ok = True
def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
return dialect.type_descriptor(UUID()) # type: ignore
else:
return dialect.type_descriptor(CHAR(32))
return dialect.type_descriptor(CHAR(32)) # type: ignore
def process_bind_param(self, value, dialect):
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
if value is None:
return value
elif dialect.name == "postgresql":
@@ -51,10 +52,10 @@ class GUID(TypeDecorator):
# hexstring
return f"{value.int:x}"
def process_result_value(self, value, dialect):
def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]:
if value is None:
return value
else:
if not isinstance(value, uuid.UUID):
value = uuid.UUID(value)
return value
return cast(uuid.UUID, value)
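
With the Dialect and TypeEngine annotations in place the GUID hooks now have concrete signatures, and the final cast makes the Optional[uuid.UUID] return explicit; the class-level ignores on AutoString and GUID presumably quiet strict-mode complaints about subclassing the stubbed TypeDecorator. A short usage sketch, assuming this module is sqlmodel.sql.sqltypes (table and column names are illustrative):

```python
import uuid

from sqlalchemy import Column, MetaData, Table

from sqlmodel.sql.sqltypes import GUID

metadata = MetaData()
items = Table(
    "items",
    metadata,
    # GUID stores a native UUID on PostgreSQL and a CHAR(32) hex string elsewhere.
    Column("id", GUID(), primary_key=True, default=uuid.uuid4),
)
```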