🐛 Fix AsyncSession type annotations for exec() (#58)

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Arseny Boykov 2023-10-23 17:58:16 +03:00 committed by GitHub
parent b8996f0e62
commit 9732c5ac60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 8 deletions

View File

@ -1,17 +1,17 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel.sql.base import Executable
from ...engine.result import ScalarResult
from ...engine.result import Result, ScalarResult
from ...orm.session import Session
from ...sql.expression import Select
from ...sql.base import Executable
from ...sql.expression import Select, SelectOfScalar
_T = TypeVar("_T")
_TSelectParam = TypeVar("_TSelectParam")
class AsyncSession(_AsyncSession):
@ -40,14 +40,46 @@ class AsyncSession(_AsyncSession):
Session(bind=bind, binds=binds, **kw) # type: ignore
)
@overload
async def exec(
self,
statement: Union[Select[_T], Executable[_T]],
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[_TSelectParam]:
...
@overload
async def exec(
self,
statement: SelectOfScalar[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...
async def exec(
self,
statement: Union[
Select[_TSelectParam],
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> ScalarResult[_T]:
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
# TODO: the documentation says execution_options accepts a dict, but only
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore

View File

@ -4,11 +4,11 @@ from sqlalchemy import util
from sqlalchemy.orm import Query as _Query
from sqlalchemy.orm import Session as _Session
from sqlalchemy.sql.base import Executable as _Executable
from sqlmodel.sql.expression import Select, SelectOfScalar
from typing_extensions import Literal
from ..engine.result import Result, ScalarResult
from ..sql.base import Executable
from ..sql.expression import Select, SelectOfScalar
_TSelectParam = TypeVar("_TSelectParam")