diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 80267b2..f500c44 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -1,17 +1,17 @@ -from typing import Any, Mapping, Optional, Sequence, TypeVar, Union +from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload from sqlalchemy import util from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession from sqlalchemy.ext.asyncio import engine from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine from sqlalchemy.util.concurrency import greenlet_spawn -from sqlmodel.sql.base import Executable -from ...engine.result import ScalarResult +from ...engine.result import Result, ScalarResult from ...orm.session import Session -from ...sql.expression import Select +from ...sql.base import Executable +from ...sql.expression import Select, SelectOfScalar -_T = TypeVar("_T") +_TSelectParam = TypeVar("_TSelectParam") class AsyncSession(_AsyncSession): @@ -40,14 +40,46 @@ class AsyncSession(_AsyncSession): Session(bind=bind, binds=binds, **kw) # type: ignore ) + @overload async def exec( self, - statement: Union[Select[_T], Executable[_T]], + statement: Select[_TSelectParam], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Result[_TSelectParam]: + ... + + @overload + async def exec( + self, + statement: SelectOfScalar[_TSelectParam], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> ScalarResult[_TSelectParam]: + ... + + async def exec( + self, + statement: Union[ + Select[_TSelectParam], + SelectOfScalar[_TSelectParam], + Executable[_TSelectParam], + ], params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[Any, Any] = util.EMPTY_DICT, bind_arguments: Optional[Mapping[str, Any]] = None, **kw: Any, - ) -> ScalarResult[_T]: + ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]: # TODO: the documentation says execution_options accepts a dict, but only # util.immutabledict has the union() method. Is this a bug in SQLAlchemy? execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 1692fdc..0c70c29 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -4,11 +4,11 @@ from sqlalchemy import util from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session from sqlalchemy.sql.base import Executable as _Executable -from sqlmodel.sql.expression import Select, SelectOfScalar from typing_extensions import Literal from ..engine.result import Result, ScalarResult from ..sql.base import Executable +from ..sql.expression import Select, SelectOfScalar _TSelectParam = TypeVar("_TSelectParam")