Update type annotations and upgrade mypy (#173)

This commit is contained in:
Sebastián Ramírez 2021-11-30 17:12:28 +01:00 committed by GitHub
parent 02da85c9ec
commit e30c7ef4e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 90 additions and 76 deletions

View File

@ -37,7 +37,7 @@ sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^6.2.4" pytest = "^6.2.4"
mypy = "^0.812" mypy = "^0.910"
flake8 = "^3.9.2" flake8 = "^3.9.2"
black = {version = "^21.5-beta.1", python = "^3.7"} black = {version = "^21.5-beta.1", python = "^3.7"}
mkdocs = "^1.2.1" mkdocs = "^1.2.1"
@ -98,3 +98,7 @@ warn_return_any = true
implicit_reexport = false implicit_reexport = false
strict_equality = true strict_equality = true
# --strict end # --strict end
[[tool.mypy.overrides]]
module = "sqlmodel.sql.expression"
warn_unused_ignores = false

View File

@ -136,4 +136,4 @@ def create_engine(
if not isinstance(query_cache_size, _DefaultPlaceholder): if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs) current_kwargs.update(kwargs)
return _create_engine(url, **current_kwargs) return _create_engine(url, **current_kwargs) # type: ignore

View File

@ -23,7 +23,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
return super().__iter__() return super().__iter__()
def __next__(self) -> _T: def __next__(self) -> _T:
return super().__next__() return super().__next__() # type: ignore
def first(self) -> Optional[_T]: def first(self) -> Optional[_T]:
return super().first() return super().first()
@ -32,7 +32,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
return super().one_or_none() return super().one_or_none()
def one(self) -> _T: def one(self) -> _T:
return super().one() return super().one() # type: ignore
class Result(_Result, Generic[_T]): class Result(_Result, Generic[_T]):
@ -70,10 +70,10 @@ class Result(_Result, Generic[_T]):
return super().scalar_one() # type: ignore return super().scalar_one() # type: ignore
def scalar_one_or_none(self) -> Optional[_T]: def scalar_one_or_none(self) -> Optional[_T]:
return super().scalar_one_or_none() # type: ignore return super().scalar_one_or_none()
def one(self) -> _T: # type: ignore def one(self) -> _T: # type: ignore
return super().one() # type: ignore return super().one() # type: ignore
def scalar(self) -> Optional[_T]: def scalar(self) -> Optional[_T]:
return super().scalar() # type: ignore return super().scalar()

View File

@ -21,7 +21,7 @@ class AsyncSession(_AsyncSession):
self, self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
**kw, **kw: Any,
): ):
# All the same code of the original AsyncSession # All the same code of the original AsyncSession
kw["future"] = True kw["future"] = True
@ -52,7 +52,7 @@ class AsyncSession(_AsyncSession):
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy? # util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
return await greenlet_spawn( # type: ignore return await greenlet_spawn(
self.sync_session.exec, self.sync_session.exec,
statement, statement,
params=params, params=params,

View File

@ -101,7 +101,7 @@ class RelationshipInfo(Representation):
*, *,
back_populates: Optional[str] = None, back_populates: Optional[str] = None,
link_model: Optional[Any] = None, link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None, sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> None: ) -> None:
@ -127,32 +127,32 @@ def Field(
default: Any = Undefined, default: Any = Undefined,
*, *,
default_factory: Optional[NoArgAnyCallable] = None, default_factory: Optional[NoArgAnyCallable] = None,
alias: str = None, alias: Optional[str] = None,
title: str = None, title: Optional[str] = None,
description: str = None, description: Optional[str] = None,
exclude: Union[ exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None, ] = None,
include: Union[ include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None, ] = None,
const: bool = None, const: Optional[bool] = None,
gt: float = None, gt: Optional[float] = None,
ge: float = None, ge: Optional[float] = None,
lt: float = None, lt: Optional[float] = None,
le: float = None, le: Optional[float] = None,
multiple_of: float = None, multiple_of: Optional[float] = None,
min_items: int = None, min_items: Optional[int] = None,
max_items: int = None, max_items: Optional[int] = None,
min_length: int = None, min_length: Optional[int] = None,
max_length: int = None, max_length: Optional[int] = None,
allow_mutation: bool = True, allow_mutation: bool = True,
regex: str = None, regex: Optional[str] = None,
primary_key: bool = False, primary_key: bool = False,
foreign_key: Optional[Any] = None, foreign_key: Optional[Any] = None,
nullable: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None, schema_extra: Optional[Dict[str, Any]] = None,
@ -195,7 +195,7 @@ def Relationship(
*, *,
back_populates: Optional[str] = None, back_populates: Optional[str] = None,
link_model: Optional[Any] = None, link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None, sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any: ) -> Any:
@ -217,19 +217,25 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# Replicate SQLAlchemy # Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None: def __setattr__(cls, name: str, value: Any) -> None:
if getattr(cls.__config__, "table", False): # type: ignore if getattr(cls.__config__, "table", False):
DeclarativeMeta.__setattr__(cls, name, value) DeclarativeMeta.__setattr__(cls, name, value)
else: else:
super().__setattr__(name, value) super().__setattr__(name, value)
def __delattr__(cls, name: str) -> None: def __delattr__(cls, name: str) -> None:
if getattr(cls.__config__, "table", False): # type: ignore if getattr(cls.__config__, "table", False):
DeclarativeMeta.__delattr__(cls, name) DeclarativeMeta.__delattr__(cls, name)
else: else:
super().__delattr__(name) super().__delattr__(name)
# From Pydantic # From Pydantic
def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any: def __new__(
cls,
name: str,
bases: Tuple[Type[Any], ...],
class_dict: Dict[str, Any],
**kwargs: Any,
) -> Any:
relationships: Dict[str, RelationshipInfo] = {} relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {} dict_for_pydantic = {}
original_annotations = resolve_annotations( original_annotations = resolve_annotations(
@ -342,7 +348,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
) )
relationship_to = temp_field.type_ relationship_to = temp_field.type_
if isinstance(temp_field.type_, ForwardRef): if isinstance(temp_field.type_, ForwardRef):
relationship_to = temp_field.type_.__forward_arg__ # type: ignore relationship_to = temp_field.type_.__forward_arg__
rel_kwargs: Dict[str, Any] = {} rel_kwargs: Dict[str, Any] = {}
if rel_info.back_populates: if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates rel_kwargs["back_populates"] = rel_info.back_populates
@ -360,7 +366,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
rel_args.extend(rel_info.sa_relationship_args) rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs: if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs) rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value: RelationshipProperty = relationship( rel_value: RelationshipProperty = relationship( # type: ignore
relationship_to, *rel_args, **rel_kwargs relationship_to, *rel_args, **rel_kwargs
) )
dict_used[rel_name] = rel_value dict_used[rel_name] = rel_value
@ -408,7 +414,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
return GUID return GUID
def get_column_from_field(field: ModelField) -> Column: def get_column_from_field(field: ModelField) -> Column: # type: ignore
sa_column = getattr(field.field_info, "sa_column", Undefined) sa_column = getattr(field.field_info, "sa_column", Undefined)
if isinstance(sa_column, Column): if isinstance(sa_column, Column):
return sa_column return sa_column
@ -440,10 +446,10 @@ def get_column_from_field(field: ModelField) -> Column:
kwargs["default"] = sa_default kwargs["default"] = sa_default
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined: if sa_column_args is not Undefined:
args.extend(list(cast(Sequence, sa_column_args))) args.extend(list(cast(Sequence[Any], sa_column_args)))
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
if sa_column_kwargs is not Undefined: if sa_column_kwargs is not Undefined:
kwargs.update(cast(dict, sa_column_kwargs)) kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
return Column(sa_type, *args, **kwargs) return Column(sa_type, *args, **kwargs)
@ -452,24 +458,27 @@ class_registry = weakref.WeakValueDictionary() # type: ignore
default_registry = registry() default_registry = registry()
def _value_items_is_true(v) -> bool: def _value_items_is_true(v: Any) -> bool:
# Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
# the current latest, Pydantic 1.8.2 # the current latest, Pydantic 1.8.2
return v is True or v is ... return v is True or v is ...
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",) __slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]] __tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
__name__: ClassVar[str] __name__: ClassVar[str]
metadata: ClassVar[MetaData] metadata: ClassVar[MetaData]
class Config: class Config:
orm_mode = True orm_mode = True
def __new__(cls, *args, **kwargs) -> Any: def __new__(cls, *args: Any, **kwargs: Any) -> Any:
new_object = super().__new__(cls) new_object = super().__new__(cls)
# SQLAlchemy doesn't call __init__ on the base class # SQLAlchemy doesn't call __init__ on the base class
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
@ -520,7 +529,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
super().__setattr__(name, value) super().__setattr__(name, value)
@classmethod @classmethod
def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None): def from_orm(
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
) -> _TSQLModel:
# Duplicated from Pydantic # Duplicated from Pydantic
if not cls.__config__.orm_mode: if not cls.__config__.orm_mode:
raise ConfigError( raise ConfigError(
@ -533,7 +544,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# End SQLModel support dict # End SQLModel support dict
if not getattr(cls.__config__, "table", False): if not getattr(cls.__config__, "table", False):
# If not table, normal Pydantic code # If not table, normal Pydantic code
m = cls.__new__(cls) m: _TSQLModel = cls.__new__(cls)
else: else:
# If table, create the new instance normally to make SQLAlchemy create # If table, create the new instance normally to make SQLAlchemy create
# the _sa_instance_state attribute # the _sa_instance_state attribute
@ -554,7 +565,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
@classmethod @classmethod
def parse_obj( def parse_obj(
cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
) -> "SQLModel": ) -> "SQLModel":
obj = cls._enforce_dict_if_root(obj) obj = cls._enforce_dict_if_root(obj)
# SQLModel, support update dict # SQLModel, support update dict

View File

@ -60,7 +60,7 @@ class Session(_Session):
results = super().execute( results = super().execute(
statement, statement,
params=params, params=params,
execution_options=execution_options, # type: ignore execution_options=execution_options,
bind_arguments=bind_arguments, bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state, _parent_execute_state=_parent_execute_state,
_add_event=_add_event, _add_event=_add_event,
@ -74,7 +74,7 @@ class Session(_Session):
self, self,
statement: _Executable, statement: _Executable,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT, execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None, bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None, _parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None, _add_event: Optional[Any] = None,
@ -101,7 +101,7 @@ class Session(_Session):
return super().execute( # type: ignore return super().execute( # type: ignore
statement, statement,
params=params, params=params,
execution_options=execution_options, # type: ignore execution_options=execution_options,
bind_arguments=bind_arguments, bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state, _parent_execute_state=_parent_execute_state,
_add_event=_add_event, _add_event=_add_event,

View File

@ -6,6 +6,4 @@ _T = TypeVar("_T")
class Executable(_Executable, Generic[_T]): class Executable(_Executable, Generic[_T]):
def __init__(self, *args, **kwargs): pass
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
super(_Executable, self).__init__(*args, **kwargs)

View File

@ -45,10 +45,10 @@ else:
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
pass pass
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass pass
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass pass
# Cast them for editors to work correctly, from several tricks tried, this works # Cast them for editors to work correctly, from several tricks tried, this works
@ -65,9 +65,9 @@ if TYPE_CHECKING: # pragma: no cover
_TScalar_0 = TypeVar( _TScalar_0 = TypeVar(
"_TScalar_0", "_TScalar_0",
Column, Column, # type: ignore
Sequence, Sequence, # type: ignore
Mapping, Mapping, # type: ignore
UUID, UUID,
datetime, datetime,
float, float,
@ -83,9 +83,9 @@ _TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
_TScalar_1 = TypeVar( _TScalar_1 = TypeVar(
"_TScalar_1", "_TScalar_1",
Column, Column, # type: ignore
Sequence, Sequence, # type: ignore
Mapping, Mapping, # type: ignore
UUID, UUID,
datetime, datetime,
float, float,
@ -101,9 +101,9 @@ _TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
_TScalar_2 = TypeVar( _TScalar_2 = TypeVar(
"_TScalar_2", "_TScalar_2",
Column, Column, # type: ignore
Sequence, Sequence, # type: ignore
Mapping, Mapping, # type: ignore
UUID, UUID,
datetime, datetime,
float, float,
@ -119,9 +119,9 @@ _TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
_TScalar_3 = TypeVar( _TScalar_3 = TypeVar(
"_TScalar_3", "_TScalar_3",
Column, Column, # type: ignore
Sequence, Sequence, # type: ignore
Mapping, Mapping, # type: ignore
UUID, UUID,
datetime, datetime,
float, float,
@ -446,14 +446,14 @@ def select( # type: ignore
# Generated overloads end # Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
if len(entities) == 1: if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw) # type: ignore return Select._create(*entities, **kw) # type: ignore
# TODO: add several @overload from Python types to SQLAlchemy equivalents # TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause: def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression return column_expression

View File

@ -63,9 +63,9 @@ if TYPE_CHECKING: # pragma: no cover
{% for i in range(number_of_types) %} {% for i in range(number_of_types) %}
_TScalar_{{ i }} = TypeVar( _TScalar_{{ i }} = TypeVar(
"_TScalar_{{ i }}", "_TScalar_{{ i }}",
Column, Column, # type: ignore
Sequence, Sequence, # type: ignore
Mapping, Mapping, # type: ignore
UUID, UUID,
datetime, datetime,
float, float,
@ -106,14 +106,14 @@ def select( # type: ignore
# Generated overloads end # Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
if len(entities) == 1: if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw) return Select._create(*entities, **kw) # type: ignore
# TODO: add several @overload from Python types to SQLAlchemy equivalents # TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause: def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression return column_expression

View File

@ -1,13 +1,14 @@
import uuid import uuid
from typing import Any, cast from typing import Any, Optional, cast
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.types import CHAR, TypeDecorator from sqlalchemy.types import CHAR, TypeDecorator
class AutoString(types.TypeDecorator): class AutoString(types.TypeDecorator): # type: ignore
impl = types.String impl = types.String
cache_ok = True cache_ok = True
@ -22,7 +23,7 @@ class AutoString(types.TypeDecorator):
# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type # Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
# with small modifications # with small modifications
class GUID(TypeDecorator): class GUID(TypeDecorator): # type: ignore
"""Platform-independent GUID type. """Platform-independent GUID type.
Uses PostgreSQL's UUID type, otherwise uses Uses PostgreSQL's UUID type, otherwise uses
@ -33,13 +34,13 @@ class GUID(TypeDecorator):
impl = CHAR impl = CHAR
cache_ok = True cache_ok = True
def load_dialect_impl(self, dialect): def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
if dialect.name == "postgresql": if dialect.name == "postgresql":
return dialect.type_descriptor(UUID()) return dialect.type_descriptor(UUID()) # type: ignore
else: else:
return dialect.type_descriptor(CHAR(32)) return dialect.type_descriptor(CHAR(32)) # type: ignore
def process_bind_param(self, value, dialect): def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
if value is None: if value is None:
return value return value
elif dialect.name == "postgresql": elif dialect.name == "postgresql":
@ -51,10 +52,10 @@ class GUID(TypeDecorator):
# hexstring # hexstring
return f"{value.int:x}" return f"{value.int:x}"
def process_result_value(self, value, dialect): def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]:
if value is None: if value is None:
return value return value
else: else:
if not isinstance(value, uuid.UUID): if not isinstance(value, uuid.UUID):
value = uuid.UUID(value) value = uuid.UUID(value)
return value return cast(uuid.UUID, value)