✨ Update type annotations and upgrade mypy (#173)
This commit is contained in:
parent
02da85c9ec
commit
e30c7ef4e9
@ -37,7 +37,7 @@ sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^6.2.4"
|
||||
mypy = "^0.812"
|
||||
mypy = "^0.910"
|
||||
flake8 = "^3.9.2"
|
||||
black = {version = "^21.5-beta.1", python = "^3.7"}
|
||||
mkdocs = "^1.2.1"
|
||||
@ -98,3 +98,7 @@ warn_return_any = true
|
||||
implicit_reexport = false
|
||||
strict_equality = true
|
||||
# --strict end
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "sqlmodel.sql.expression"
|
||||
warn_unused_ignores = false
|
||||
|
@ -136,4 +136,4 @@ def create_engine(
|
||||
if not isinstance(query_cache_size, _DefaultPlaceholder):
|
||||
current_kwargs["query_cache_size"] = query_cache_size
|
||||
current_kwargs.update(kwargs)
|
||||
return _create_engine(url, **current_kwargs)
|
||||
return _create_engine(url, **current_kwargs) # type: ignore
|
||||
|
@ -23,7 +23,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
|
||||
return super().__iter__()
|
||||
|
||||
def __next__(self) -> _T:
|
||||
return super().__next__()
|
||||
return super().__next__() # type: ignore
|
||||
|
||||
def first(self) -> Optional[_T]:
|
||||
return super().first()
|
||||
@ -32,7 +32,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
|
||||
return super().one_or_none()
|
||||
|
||||
def one(self) -> _T:
|
||||
return super().one()
|
||||
return super().one() # type: ignore
|
||||
|
||||
|
||||
class Result(_Result, Generic[_T]):
|
||||
@ -70,10 +70,10 @@ class Result(_Result, Generic[_T]):
|
||||
return super().scalar_one() # type: ignore
|
||||
|
||||
def scalar_one_or_none(self) -> Optional[_T]:
|
||||
return super().scalar_one_or_none() # type: ignore
|
||||
return super().scalar_one_or_none()
|
||||
|
||||
def one(self) -> _T: # type: ignore
|
||||
return super().one() # type: ignore
|
||||
|
||||
def scalar(self) -> Optional[_T]:
|
||||
return super().scalar() # type: ignore
|
||||
return super().scalar()
|
||||
|
@ -21,7 +21,7 @@ class AsyncSession(_AsyncSession):
|
||||
self,
|
||||
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
|
||||
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
|
||||
**kw,
|
||||
**kw: Any,
|
||||
):
|
||||
# All the same code of the original AsyncSession
|
||||
kw["future"] = True
|
||||
@ -52,7 +52,7 @@ class AsyncSession(_AsyncSession):
|
||||
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
|
||||
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
|
||||
|
||||
return await greenlet_spawn( # type: ignore
|
||||
return await greenlet_spawn(
|
||||
self.sync_session.exec,
|
||||
statement,
|
||||
params=params,
|
||||
|
@ -101,7 +101,7 @@ class RelationshipInfo(Representation):
|
||||
*,
|
||||
back_populates: Optional[str] = None,
|
||||
link_model: Optional[Any] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
|
||||
sa_relationship_args: Optional[Sequence[Any]] = None,
|
||||
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> None:
|
||||
@ -127,32 +127,32 @@ def Field(
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: str = None,
|
||||
title: str = None,
|
||||
description: str = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
include: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
const: bool = None,
|
||||
gt: float = None,
|
||||
ge: float = None,
|
||||
lt: float = None,
|
||||
le: float = None,
|
||||
multiple_of: float = None,
|
||||
min_items: int = None,
|
||||
max_items: int = None,
|
||||
min_length: int = None,
|
||||
max_length: int = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: str = None,
|
||||
regex: Optional[str] = None,
|
||||
primary_key: bool = False,
|
||||
foreign_key: Optional[Any] = None,
|
||||
nullable: Union[bool, UndefinedType] = Undefined,
|
||||
index: Union[bool, UndefinedType] = Undefined,
|
||||
sa_column: Union[Column, UndefinedType] = Undefined,
|
||||
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
|
||||
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
|
||||
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
|
||||
schema_extra: Optional[Dict[str, Any]] = None,
|
||||
@ -195,7 +195,7 @@ def Relationship(
|
||||
*,
|
||||
back_populates: Optional[str] = None,
|
||||
link_model: Optional[Any] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
|
||||
sa_relationship_args: Optional[Sequence[Any]] = None,
|
||||
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> Any:
|
||||
@ -217,19 +217,25 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
|
||||
# Replicate SQLAlchemy
|
||||
def __setattr__(cls, name: str, value: Any) -> None:
|
||||
if getattr(cls.__config__, "table", False): # type: ignore
|
||||
if getattr(cls.__config__, "table", False):
|
||||
DeclarativeMeta.__setattr__(cls, name, value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __delattr__(cls, name: str) -> None:
|
||||
if getattr(cls.__config__, "table", False): # type: ignore
|
||||
if getattr(cls.__config__, "table", False):
|
||||
DeclarativeMeta.__delattr__(cls, name)
|
||||
else:
|
||||
super().__delattr__(name)
|
||||
|
||||
# From Pydantic
|
||||
def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: Tuple[Type[Any], ...],
|
||||
class_dict: Dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
relationships: Dict[str, RelationshipInfo] = {}
|
||||
dict_for_pydantic = {}
|
||||
original_annotations = resolve_annotations(
|
||||
@ -342,7 +348,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
)
|
||||
relationship_to = temp_field.type_
|
||||
if isinstance(temp_field.type_, ForwardRef):
|
||||
relationship_to = temp_field.type_.__forward_arg__ # type: ignore
|
||||
relationship_to = temp_field.type_.__forward_arg__
|
||||
rel_kwargs: Dict[str, Any] = {}
|
||||
if rel_info.back_populates:
|
||||
rel_kwargs["back_populates"] = rel_info.back_populates
|
||||
@ -360,7 +366,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
rel_args.extend(rel_info.sa_relationship_args)
|
||||
if rel_info.sa_relationship_kwargs:
|
||||
rel_kwargs.update(rel_info.sa_relationship_kwargs)
|
||||
rel_value: RelationshipProperty = relationship(
|
||||
rel_value: RelationshipProperty = relationship( # type: ignore
|
||||
relationship_to, *rel_args, **rel_kwargs
|
||||
)
|
||||
dict_used[rel_name] = rel_value
|
||||
@ -408,7 +414,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
|
||||
return GUID
|
||||
|
||||
|
||||
def get_column_from_field(field: ModelField) -> Column:
|
||||
def get_column_from_field(field: ModelField) -> Column: # type: ignore
|
||||
sa_column = getattr(field.field_info, "sa_column", Undefined)
|
||||
if isinstance(sa_column, Column):
|
||||
return sa_column
|
||||
@ -440,10 +446,10 @@ def get_column_from_field(field: ModelField) -> Column:
|
||||
kwargs["default"] = sa_default
|
||||
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
|
||||
if sa_column_args is not Undefined:
|
||||
args.extend(list(cast(Sequence, sa_column_args)))
|
||||
args.extend(list(cast(Sequence[Any], sa_column_args)))
|
||||
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
|
||||
if sa_column_kwargs is not Undefined:
|
||||
kwargs.update(cast(dict, sa_column_kwargs))
|
||||
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
|
||||
return Column(sa_type, *args, **kwargs)
|
||||
|
||||
|
||||
@ -452,24 +458,27 @@ class_registry = weakref.WeakValueDictionary() # type: ignore
|
||||
default_registry = registry()
|
||||
|
||||
|
||||
def _value_items_is_true(v) -> bool:
|
||||
def _value_items_is_true(v: Any) -> bool:
|
||||
# Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
|
||||
# the current latest, Pydantic 1.8.2
|
||||
return v is True or v is ...
|
||||
|
||||
|
||||
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
|
||||
|
||||
|
||||
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
|
||||
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
|
||||
__slots__ = ("__weakref__",)
|
||||
__tablename__: ClassVar[Union[str, Callable[..., str]]]
|
||||
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
|
||||
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
|
||||
__name__: ClassVar[str]
|
||||
metadata: ClassVar[MetaData]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
def __new__(cls, *args, **kwargs) -> Any:
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
new_object = super().__new__(cls)
|
||||
# SQLAlchemy doesn't call __init__ on the base class
|
||||
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
|
||||
@ -520,7 +529,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
super().__setattr__(name, value)
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
|
||||
def from_orm(
|
||||
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
|
||||
) -> _TSQLModel:
|
||||
# Duplicated from Pydantic
|
||||
if not cls.__config__.orm_mode:
|
||||
raise ConfigError(
|
||||
@ -533,7 +544,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
# End SQLModel support dict
|
||||
if not getattr(cls.__config__, "table", False):
|
||||
# If not table, normal Pydantic code
|
||||
m = cls.__new__(cls)
|
||||
m: _TSQLModel = cls.__new__(cls)
|
||||
else:
|
||||
# If table, create the new instance normally to make SQLAlchemy create
|
||||
# the _sa_instance_state attribute
|
||||
@ -554,7 +565,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
|
||||
@classmethod
|
||||
def parse_obj(
|
||||
cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
|
||||
cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
|
||||
) -> "SQLModel":
|
||||
obj = cls._enforce_dict_if_root(obj)
|
||||
# SQLModel, support update dict
|
||||
|
@ -60,7 +60,7 @@ class Session(_Session):
|
||||
results = super().execute(
|
||||
statement,
|
||||
params=params,
|
||||
execution_options=execution_options, # type: ignore
|
||||
execution_options=execution_options,
|
||||
bind_arguments=bind_arguments,
|
||||
_parent_execute_state=_parent_execute_state,
|
||||
_add_event=_add_event,
|
||||
@ -74,7 +74,7 @@ class Session(_Session):
|
||||
self,
|
||||
statement: _Executable,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
@ -101,7 +101,7 @@ class Session(_Session):
|
||||
return super().execute( # type: ignore
|
||||
statement,
|
||||
params=params,
|
||||
execution_options=execution_options, # type: ignore
|
||||
execution_options=execution_options,
|
||||
bind_arguments=bind_arguments,
|
||||
_parent_execute_state=_parent_execute_state,
|
||||
_add_event=_add_event,
|
||||
|
@ -6,6 +6,4 @@ _T = TypeVar("_T")
|
||||
|
||||
|
||||
class Executable(_Executable, Generic[_T]):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
|
||||
super(_Executable, self).__init__(*args, **kwargs)
|
||||
pass
|
||||
|
@ -45,10 +45,10 @@ else:
|
||||
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
|
||||
pass
|
||||
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
# Cast them for editors to work correctly, from several tricks tried, this works
|
||||
@ -65,9 +65,9 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
_TScalar_0 = TypeVar(
|
||||
"_TScalar_0",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@ -83,9 +83,9 @@ _TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
|
||||
|
||||
_TScalar_1 = TypeVar(
|
||||
"_TScalar_1",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@ -101,9 +101,9 @@ _TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
|
||||
|
||||
_TScalar_2 = TypeVar(
|
||||
"_TScalar_2",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@ -119,9 +119,9 @@ _TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
|
||||
|
||||
_TScalar_3 = TypeVar(
|
||||
"_TScalar_3",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@ -446,14 +446,14 @@ def select( # type: ignore
|
||||
# Generated overloads end
|
||||
|
||||
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
|
||||
if len(entities) == 1:
|
||||
return SelectOfScalar._create(*entities, **kw) # type: ignore
|
||||
return Select._create(*entities, **kw) # type: ignore
|
||||
|
||||
|
||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
|
||||
def col(column_expression: Any) -> ColumnClause:
|
||||
def col(column_expression: Any) -> ColumnClause: # type: ignore
|
||||
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
|
||||
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
|
||||
return column_expression
|
||||
|
@ -63,9 +63,9 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
{% for i in range(number_of_types) %}
|
||||
_TScalar_{{ i }} = TypeVar(
|
||||
"_TScalar_{{ i }}",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@ -106,14 +106,14 @@ def select( # type: ignore
|
||||
|
||||
# Generated overloads end
|
||||
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
|
||||
if len(entities) == 1:
|
||||
return SelectOfScalar._create(*entities, **kw) # type: ignore
|
||||
return Select._create(*entities, **kw)
|
||||
return Select._create(*entities, **kw) # type: ignore
|
||||
|
||||
|
||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
|
||||
def col(column_expression: Any) -> ColumnClause:
|
||||
def col(column_expression: Any) -> ColumnClause: # type: ignore
|
||||
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
|
||||
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
|
||||
return column_expression
|
||||
|
@ -1,13 +1,14 @@
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
from sqlalchemy.types import CHAR, TypeDecorator
|
||||
|
||||
|
||||
class AutoString(types.TypeDecorator):
|
||||
class AutoString(types.TypeDecorator): # type: ignore
|
||||
|
||||
impl = types.String
|
||||
cache_ok = True
|
||||
@ -22,7 +23,7 @@ class AutoString(types.TypeDecorator):
|
||||
|
||||
# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
|
||||
# with small modifications
|
||||
class GUID(TypeDecorator):
|
||||
class GUID(TypeDecorator): # type: ignore
|
||||
"""Platform-independent GUID type.
|
||||
|
||||
Uses PostgreSQL's UUID type, otherwise uses
|
||||
@ -33,13 +34,13 @@ class GUID(TypeDecorator):
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(UUID())
|
||||
return dialect.type_descriptor(UUID()) # type: ignore
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(32))
|
||||
return dialect.type_descriptor(CHAR(32)) # type: ignore
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == "postgresql":
|
||||
@ -51,10 +52,10 @@ class GUID(TypeDecorator):
|
||||
# hexstring
|
||||
return f"{value.int:x}"
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]:
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
if not isinstance(value, uuid.UUID):
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
return cast(uuid.UUID, value)
|
||||
|
Loading…
x
Reference in New Issue
Block a user