Upgrade SQLAlchemy to 2.0, including initial work by farahats9 (#700)

Co-authored-by: Mohamed Farahat <farahats9@yahoo.com>
Co-authored-by: Stefan Borer <stefan.borer@gmail.com>
Co-authored-by: Peter Landry <peter.landry@gmail.com>
This commit is contained in:
Sebastián Ramírez 2023-11-18 12:30:37 +01:00 committed by GitHub
parent 77c6fed305
commit 8ed856d322
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 808 additions and 510 deletions

View File

@ -56,6 +56,8 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: python -m poetry install
- name: Lint
# Do not run on Python 3.7 as mypy behaves differently
if: matrix.python-version != '3.7'
run: python -m poetry run bash scripts/lint.sh
- run: mkdir coverage
- name: Test

View File

@ -31,9 +31,8 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.7"
SQLAlchemy = ">=1.4.36,<2.0.0"
SQLAlchemy = ">=2.0.0,<2.1.0"
pydantic = "^1.9.0"
sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
[tool.poetry.group.dev.dependencies]
pytest = "^7.0.1"
@ -45,9 +44,10 @@ pillow = "^9.3.0"
cairosvg = "^2.5.2"
mdx-include = "^1.4.1"
coverage = {extras = ["toml"], version = ">=6.2,<8.0"}
fastapi = "^0.68.1"
requests = "^2.26.0"
fastapi = "^0.103.2"
ruff = "^0.1.2"
# For FastAPI tests
httpx = "0.24.1"
[build-system]
requires = ["poetry-core"]
@ -80,6 +80,12 @@ strict = true
module = "sqlmodel.sql.expression"
warn_unused_ignores = false
[[tool.mypy.overrides]]
module = "docs_src.*"
disallow_incomplete_defs = false
disallow_untyped_defs = false
disallow_untyped_calls = false
[tool.ruff]
select = [
"E", # pycodestyle errors

View File

@ -34,9 +34,9 @@ for total_args in range(2, number_of_types + 1):
arg = Arg(name=f"entity_{i}", annotation=t_var)
ret_type = t_var
else:
t_type = f"_TModel_{i}"
t_var = f"Type[{t_type}]"
arg = Arg(name=f"entity_{i}", annotation=t_var)
t_type = f"_T{i}"
t_var = f"_TCCA[{t_type}]"
arg = Arg(name=f"__ent{i}", annotation=t_var)
ret_type = t_type
args.append(arg)
return_types.append(ret_type)

View File

@ -1,9 +1,12 @@
__version__ = "0.0.11"
# Re-export from SQLAlchemy
from sqlalchemy.engine import create_engine as create_engine
from sqlalchemy.engine import create_mock_engine as create_mock_engine
from sqlalchemy.engine import engine_from_config as engine_from_config
from sqlalchemy.inspection import inspect as inspect
from sqlalchemy.pool import QueuePool as QueuePool
from sqlalchemy.pool import StaticPool as StaticPool
from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
from sqlalchemy.schema import DDL as DDL
from sqlalchemy.schema import CheckConstraint as CheckConstraint
@ -21,7 +24,6 @@ from sqlalchemy.schema import MetaData as MetaData
from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
from sqlalchemy.schema import Sequence as Sequence
from sqlalchemy.schema import Table as Table
from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
from sqlalchemy.sql import (
@ -32,26 +34,14 @@ from sqlalchemy.sql import (
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
)
from sqlalchemy.sql import alias as alias
from sqlalchemy.sql import all_ as all_
from sqlalchemy.sql import and_ as and_
from sqlalchemy.sql import any_ as any_
from sqlalchemy.sql import asc as asc
from sqlalchemy.sql import between as between
from sqlalchemy.sql import bindparam as bindparam
from sqlalchemy.sql import case as case
from sqlalchemy.sql import cast as cast
from sqlalchemy.sql import collate as collate
from sqlalchemy.sql import column as column
from sqlalchemy.sql import delete as delete
from sqlalchemy.sql import desc as desc
from sqlalchemy.sql import distinct as distinct
from sqlalchemy.sql import except_ as except_
from sqlalchemy.sql import except_all as except_all
from sqlalchemy.sql import exists as exists
from sqlalchemy.sql import extract as extract
from sqlalchemy.sql import false as false
from sqlalchemy.sql import func as func
from sqlalchemy.sql import funcfilter as funcfilter
from sqlalchemy.sql import insert as insert
from sqlalchemy.sql import intersect as intersect
from sqlalchemy.sql import intersect_all as intersect_all
@ -61,28 +51,19 @@ from sqlalchemy.sql import lateral as lateral
from sqlalchemy.sql import literal as literal
from sqlalchemy.sql import literal_column as literal_column
from sqlalchemy.sql import modifier as modifier
from sqlalchemy.sql import not_ as not_
from sqlalchemy.sql import null as null
from sqlalchemy.sql import nulls_first as nulls_first
from sqlalchemy.sql import nulls_last as nulls_last
from sqlalchemy.sql import nullsfirst as nullsfirst
from sqlalchemy.sql import nullslast as nullslast
from sqlalchemy.sql import or_ as or_
from sqlalchemy.sql import outerjoin as outerjoin
from sqlalchemy.sql import outparam as outparam
from sqlalchemy.sql import over as over
from sqlalchemy.sql import subquery as subquery
from sqlalchemy.sql import table as table
from sqlalchemy.sql import tablesample as tablesample
from sqlalchemy.sql import text as text
from sqlalchemy.sql import true as true
from sqlalchemy.sql import tuple_ as tuple_
from sqlalchemy.sql import type_coerce as type_coerce
from sqlalchemy.sql import union as union
from sqlalchemy.sql import union_all as union_all
from sqlalchemy.sql import update as update
from sqlalchemy.sql import values as values
from sqlalchemy.sql import within_group as within_group
from sqlalchemy.types import ARRAY as ARRAY
from sqlalchemy.types import BIGINT as BIGINT
from sqlalchemy.types import BINARY as BINARY
@ -93,6 +74,8 @@ from sqlalchemy.types import CLOB as CLOB
from sqlalchemy.types import DATE as DATE
from sqlalchemy.types import DATETIME as DATETIME
from sqlalchemy.types import DECIMAL as DECIMAL
from sqlalchemy.types import DOUBLE as DOUBLE
from sqlalchemy.types import DOUBLE_PRECISION as DOUBLE_PRECISION
from sqlalchemy.types import FLOAT as FLOAT
from sqlalchemy.types import INT as INT
from sqlalchemy.types import INTEGER as INTEGER
@ -105,12 +88,14 @@ from sqlalchemy.types import SMALLINT as SMALLINT
from sqlalchemy.types import TEXT as TEXT
from sqlalchemy.types import TIME as TIME
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
from sqlalchemy.types import UUID as UUID
from sqlalchemy.types import VARBINARY as VARBINARY
from sqlalchemy.types import VARCHAR as VARCHAR
from sqlalchemy.types import BigInteger as BigInteger
from sqlalchemy.types import Boolean as Boolean
from sqlalchemy.types import Date as Date
from sqlalchemy.types import DateTime as DateTime
from sqlalchemy.types import Double as Double
from sqlalchemy.types import Enum as Enum
from sqlalchemy.types import Float as Float
from sqlalchemy.types import Integer as Integer
@ -122,16 +107,38 @@ from sqlalchemy.types import SmallInteger as SmallInteger
from sqlalchemy.types import String as String
from sqlalchemy.types import Text as Text
from sqlalchemy.types import Time as Time
from sqlalchemy.types import TupleType as TupleType
from sqlalchemy.types import TypeDecorator as TypeDecorator
from sqlalchemy.types import Unicode as Unicode
from sqlalchemy.types import UnicodeText as UnicodeText
from sqlalchemy.types import Uuid as Uuid
# From SQLModel, modifications of SQLAlchemy or equivalents of Pydantic
from .engine.create import create_engine as create_engine
from .main import Field as Field
from .main import Relationship as Relationship
from .main import SQLModel as SQLModel
from .orm.session import Session as Session
from .sql.expression import all_ as all_
from .sql.expression import and_ as and_
from .sql.expression import any_ as any_
from .sql.expression import asc as asc
from .sql.expression import between as between
from .sql.expression import case as case
from .sql.expression import cast as cast
from .sql.expression import col as col
from .sql.expression import collate as collate
from .sql.expression import desc as desc
from .sql.expression import distinct as distinct
from .sql.expression import extract as extract
from .sql.expression import funcfilter as funcfilter
from .sql.expression import not_ as not_
from .sql.expression import nulls_first as nulls_first
from .sql.expression import nulls_last as nulls_last
from .sql.expression import or_ as or_
from .sql.expression import over as over
from .sql.expression import select as select
from .sql.expression import tuple_ as tuple_
from .sql.expression import type_coerce as type_coerce
from .sql.expression import within_group as within_group
from .sql.sqltypes import GUID as GUID
from .sql.sqltypes import AutoString as AutoString

View File

@ -1,139 +0,0 @@
import json
import sqlite3
from typing import Any, Callable, Dict, List, Optional, Type, Union
from sqlalchemy import create_engine as _create_engine
from sqlalchemy.engine.url import URL
from sqlalchemy.future import Engine as _FutureEngine
from sqlalchemy.pool import Pool
from typing_extensions import Literal, TypedDict
from ..default import Default, _DefaultPlaceholder
# Types defined in sqlalchemy2-stubs, but can't be imported, so re-define here
_Debug = Literal["debug"]
_IsolationLevel = Literal[
"SERIALIZABLE",
"REPEATABLE READ",
"READ COMMITTED",
"READ UNCOMMITTED",
"AUTOCOMMIT",
]
_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
_ResetOnReturn = Literal["rollback", "commit"]
class _SQLiteConnectArgs(TypedDict, total=False):
timeout: float
detect_types: Any
isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
check_same_thread: bool
factory: Type[sqlite3.Connection]
cached_statements: int
uri: bool
_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]]
# Re-define create_engine to have by default future=True, and assume that's what is used
# Also show the default values used for each parameter, but don't set them unless
# explicitly passed as arguments by the user to prevent errors. E.g. SQLite doesn't
# support pool connection arguments.
def create_engine(
url: Union[str, URL],
*,
connect_args: _ConnectArgs = Default({}), # type: ignore
echo: Union[bool, _Debug] = Default(False),
echo_pool: Union[bool, _Debug] = Default(False),
enable_from_linting: bool = Default(True),
encoding: str = Default("utf-8"),
execution_options: Dict[Any, Any] = Default({}),
future: bool = True,
hide_parameters: bool = Default(False),
implicit_returning: bool = Default(True),
isolation_level: Optional[_IsolationLevel] = Default(None),
json_deserializer: Callable[..., Any] = Default(json.loads),
json_serializer: Callable[..., Any] = Default(json.dumps),
label_length: Optional[int] = Default(None),
logging_name: Optional[str] = Default(None),
max_identifier_length: Optional[int] = Default(None),
max_overflow: int = Default(10),
module: Optional[Any] = Default(None),
paramstyle: Optional[_ParamStyle] = Default(None),
pool: Optional[Pool] = Default(None),
poolclass: Optional[Type[Pool]] = Default(None),
pool_logging_name: Optional[str] = Default(None),
pool_pre_ping: bool = Default(False),
pool_size: int = Default(5),
pool_recycle: int = Default(-1),
pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"),
pool_timeout: float = Default(30),
pool_use_lifo: bool = Default(False),
plugins: Optional[List[str]] = Default(None),
query_cache_size: Optional[int] = Default(None),
**kwargs: Any,
) -> _FutureEngine:
current_kwargs: Dict[str, Any] = {
"future": future,
}
if not isinstance(echo, _DefaultPlaceholder):
current_kwargs["echo"] = echo
if not isinstance(echo_pool, _DefaultPlaceholder):
current_kwargs["echo_pool"] = echo_pool
if not isinstance(enable_from_linting, _DefaultPlaceholder):
current_kwargs["enable_from_linting"] = enable_from_linting
if not isinstance(connect_args, _DefaultPlaceholder):
current_kwargs["connect_args"] = connect_args
if not isinstance(encoding, _DefaultPlaceholder):
current_kwargs["encoding"] = encoding
if not isinstance(execution_options, _DefaultPlaceholder):
current_kwargs["execution_options"] = execution_options
if not isinstance(hide_parameters, _DefaultPlaceholder):
current_kwargs["hide_parameters"] = hide_parameters
if not isinstance(implicit_returning, _DefaultPlaceholder):
current_kwargs["implicit_returning"] = implicit_returning
if not isinstance(isolation_level, _DefaultPlaceholder):
current_kwargs["isolation_level"] = isolation_level
if not isinstance(json_deserializer, _DefaultPlaceholder):
current_kwargs["json_deserializer"] = json_deserializer
if not isinstance(json_serializer, _DefaultPlaceholder):
current_kwargs["json_serializer"] = json_serializer
if not isinstance(label_length, _DefaultPlaceholder):
current_kwargs["label_length"] = label_length
if not isinstance(logging_name, _DefaultPlaceholder):
current_kwargs["logging_name"] = logging_name
if not isinstance(max_identifier_length, _DefaultPlaceholder):
current_kwargs["max_identifier_length"] = max_identifier_length
if not isinstance(max_overflow, _DefaultPlaceholder):
current_kwargs["max_overflow"] = max_overflow
if not isinstance(module, _DefaultPlaceholder):
current_kwargs["module"] = module
if not isinstance(paramstyle, _DefaultPlaceholder):
current_kwargs["paramstyle"] = paramstyle
if not isinstance(pool, _DefaultPlaceholder):
current_kwargs["pool"] = pool
if not isinstance(poolclass, _DefaultPlaceholder):
current_kwargs["poolclass"] = poolclass
if not isinstance(pool_logging_name, _DefaultPlaceholder):
current_kwargs["pool_logging_name"] = pool_logging_name
if not isinstance(pool_pre_ping, _DefaultPlaceholder):
current_kwargs["pool_pre_ping"] = pool_pre_ping
if not isinstance(pool_size, _DefaultPlaceholder):
current_kwargs["pool_size"] = pool_size
if not isinstance(pool_recycle, _DefaultPlaceholder):
current_kwargs["pool_recycle"] = pool_recycle
if not isinstance(pool_reset_on_return, _DefaultPlaceholder):
current_kwargs["pool_reset_on_return"] = pool_reset_on_return
if not isinstance(pool_timeout, _DefaultPlaceholder):
current_kwargs["pool_timeout"] = pool_timeout
if not isinstance(pool_use_lifo, _DefaultPlaceholder):
current_kwargs["pool_use_lifo"] = pool_use_lifo
if not isinstance(plugins, _DefaultPlaceholder):
current_kwargs["plugins"] = plugins
if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs)
return _create_engine(url, **current_kwargs) # type: ignore

View File

@ -1,79 +0,0 @@
from typing import Generic, Iterator, List, Optional, TypeVar
from sqlalchemy.engine.result import Result as _Result
from sqlalchemy.engine.result import ScalarResult as _ScalarResult
_T = TypeVar("_T")
class ScalarResult(_ScalarResult, Generic[_T]):
def all(self) -> List[_T]:
return super().all()
def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:
return super().partitions(size)
def fetchall(self) -> List[_T]:
return super().fetchall()
def fetchmany(self, size: Optional[int] = None) -> List[_T]:
return super().fetchmany(size)
def __iter__(self) -> Iterator[_T]:
return super().__iter__()
def __next__(self) -> _T:
return super().__next__() # type: ignore
def first(self) -> Optional[_T]:
return super().first()
def one_or_none(self) -> Optional[_T]:
return super().one_or_none()
def one(self) -> _T:
return super().one() # type: ignore
class Result(_Result, Generic[_T]):
def scalars(self, index: int = 0) -> ScalarResult[_T]:
return super().scalars(index) # type: ignore
def __iter__(self) -> Iterator[_T]: # type: ignore
return super().__iter__() # type: ignore
def __next__(self) -> _T: # type: ignore
return super().__next__() # type: ignore
def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore
return super().partitions(size) # type: ignore
def fetchall(self) -> List[_T]: # type: ignore
return super().fetchall() # type: ignore
def fetchone(self) -> Optional[_T]: # type: ignore
return super().fetchone() # type: ignore
def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore
return super().fetchmany() # type: ignore
def all(self) -> List[_T]: # type: ignore
return super().all() # type: ignore
def first(self) -> Optional[_T]: # type: ignore
return super().first() # type: ignore
def one_or_none(self) -> Optional[_T]: # type: ignore
return super().one_or_none() # type: ignore
def scalar_one(self) -> _T:
return super().scalar_one() # type: ignore
def scalar_one_or_none(self) -> Optional[_T]:
return super().scalar_one_or_none()
def one(self) -> _T: # type: ignore
return super().one() # type: ignore
def scalar(self) -> Optional[_T]:
return super().scalar()

View File

@ -1,45 +1,38 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
from typing import (
Any,
Dict,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Union,
cast,
overload,
)
from sqlalchemy import util
from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams
from sqlalchemy.engine.result import Result, ScalarResult, TupleResult
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.ext.asyncio.result import _ensure_sync_result
from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS
from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
from sqlalchemy.sql.base import Executable as _Executable
from sqlalchemy.util.concurrency import greenlet_spawn
from typing_extensions import deprecated
from ...engine.result import Result, ScalarResult
from ...orm.session import Session
from ...sql.base import Executable
from ...sql.expression import Select, SelectOfScalar
_TSelectParam = TypeVar("_TSelectParam")
_TSelectParam = TypeVar("_TSelectParam", bound=Any)
class AsyncSession(_AsyncSession):
sync_session_class: Type[Session] = Session
sync_session: Session
def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
**kw: Any,
):
# All the same code of the original AsyncSession
kw["future"] = True
if bind:
self.bind = bind
bind = engine._get_sync_engine_or_connection(bind) # type: ignore
if binds:
self.binds = binds
binds = {
key: engine._get_sync_engine_or_connection(b) # type: ignore
for key, b in binds.items()
}
self.sync_session = self._proxied = self._assign_proxied( # type: ignore
Session(bind=bind, binds=binds, **kw) # type: ignore
)
@overload
async def exec(
self,
@ -47,11 +40,10 @@ class AsyncSession(_AsyncSession):
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[_TSelectParam]:
) -> TupleResult[_TSelectParam]:
...
@overload
@ -61,10 +53,9 @@ class AsyncSession(_AsyncSession):
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...
@ -75,20 +66,87 @@ class AsyncSession(_AsyncSession):
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
**kw: Any,
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
# TODO: the documentation says execution_options accepts a dict, but only
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]:
if execution_options:
execution_options = util.immutabledict(execution_options).union(
_EXECUTE_OPTIONS
)
else:
execution_options = _EXECUTE_OPTIONS
return await greenlet_spawn(
result = await greenlet_spawn(
self.sync_session.exec,
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
)
result_value = await _ensure_sync_result(
cast(Result[_TSelectParam], result), self.exec
)
return result_value # type: ignore
@deprecated(
"""
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
This is the original SQLAlchemy `session.execute()` method that returns objects
of type `Row`, and that you have to call `scalars()` to get the model objects.
For example:
```Python
heroes = await session.execute(select(Hero)).scalars().all()
```
instead you could use `exec()`:
```Python
heroes = await session.exec(select(Hero)).all()
```
"""
)
async def execute( # type: ignore
self,
statement: _Executable,
params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[Any]:
"""
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
This is the original SQLAlchemy `session.execute()` method that returns objects
of type `Row`, and that you have to call `scalars()` to get the model objects.
For example:
```Python
heroes = await session.execute(select(Hero)).scalars().all()
```
instead you could use `exec()`:
```Python
heroes = await session.exec(select(Hero)).all()
```
"""
return await super().execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
)

View File

@ -45,12 +45,19 @@ from sqlalchemy import (
inspect,
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm import (
Mapped,
RelationshipProperty,
declared_attr,
registry,
relationship,
)
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time
from typing_extensions import get_origin
from .sql.sqltypes import GUID, AutoString
@ -483,7 +490,16 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# over anything else, use that and continue with the next attribute
setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315
continue
ann = cls.__annotations__[rel_name]
raw_ann = cls.__annotations__[rel_name]
origin = get_origin(raw_ann)
if origin is Mapped:
ann = raw_ann.__args__[0]
else:
ann = raw_ann
# Plain forward references, for models not yet defined, are not
# handled well by SQLAlchemy without Mapped, so, wrap the
# annotations in Mapped here
cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type]
temp_field = ModelField.infer(
name=rel_name,
value=rel_info,
@ -511,9 +527,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value: RelationshipProperty = relationship( # type: ignore
relationship_to, *rel_args, **rel_kwargs
)
rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
setattr(cls, rel_name, rel_value) # Fix #315
# SQLAlchemy no longer uses dict_
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
@ -642,6 +656,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
__allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six
class Config:
orm_mode = True
@ -685,7 +700,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
if getattr(self.__config__, "table", False) and is_instrumented(self, name):
if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore
set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values

View File

@ -1,16 +1,27 @@
from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload
from typing import (
Any,
Dict,
Mapping,
Optional,
Sequence,
TypeVar,
Union,
overload,
)
from sqlalchemy import util
from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams
from sqlalchemy.engine.result import Result, ScalarResult, TupleResult
from sqlalchemy.orm import Query as _Query
from sqlalchemy.orm import Session as _Session
from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
from sqlalchemy.sql._typing import _ColumnsClauseArgument
from sqlalchemy.sql.base import Executable as _Executable
from typing_extensions import Literal
from sqlmodel.sql.base import Executable
from sqlmodel.sql.expression import Select, SelectOfScalar
from typing_extensions import deprecated
from ..engine.result import Result, ScalarResult
from ..sql.base import Executable
from ..sql.expression import Select, SelectOfScalar
_TSelectParam = TypeVar("_TSelectParam")
_TSelectParam = TypeVar("_TSelectParam", bound=Any)
class Session(_Session):
@ -21,11 +32,10 @@ class Session(_Session):
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[_TSelectParam]:
) -> TupleResult[_TSelectParam]:
...
@overload
@ -35,10 +45,9 @@ class Session(_Session):
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> ScalarResult[_TSelectParam]:
...
@ -52,11 +61,10 @@ class Session(_Session):
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]:
results = super().execute(
statement,
params=params,
@ -64,21 +72,40 @@ class Session(_Session):
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
**kw,
)
if isinstance(statement, SelectOfScalar):
return results.scalars() # type: ignore
return results.scalars()
return results # type: ignore
def execute(
@deprecated(
"""
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
This is the original SQLAlchemy `session.execute()` method that returns objects
of type `Row`, and that you have to call `scalars()` to get the model objects.
For example:
```Python
heroes = session.execute(select(Hero)).scalars().all()
```
instead you could use `exec()`:
```Python
heroes = session.exec(select(Hero)).all()
```
"""
)
def execute( # type: ignore
self,
statement: _Executable,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[Any]:
"""
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
@ -98,17 +125,29 @@ class Session(_Session):
heroes = session.exec(select(Hero)).all()
```
"""
return super().execute( # type: ignore
return super().execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
**kw,
)
def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":
@deprecated(
"""
🚨 You probably want to use `session.exec()` instead of `session.query()`.
`session.exec()` is SQLModel's own short version with increased type
annotations.
Or otherwise you might want to use `session.execute()` instead of
`session.query()`.
"""
)
def query( # type: ignore
self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
) -> _Query[Any]:
"""
🚨 You probably want to use `session.exec()` instead of `session.query()`.
@ -119,23 +158,3 @@ class Session(_Session):
`session.query()`.
"""
return super().query(*entities, **kwargs)
def get(
self,
entity: Type[_TSelectParam],
ident: Any,
options: Optional[Sequence[Any]] = None,
populate_existing: bool = False,
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
identity_token: Optional[Any] = None,
execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT,
) -> Optional[_TSelectParam]:
return super().get(
entity,
ident,
options=options,
populate_existing=populate_existing,
with_for_update=with_for_update,
identity_token=identity_token,
execution_options=execution_options,
)

View File

@ -2,10 +2,10 @@
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
@ -15,15 +15,223 @@ from typing import (
)
from uuid import UUID
from sqlalchemy import Column
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnClause
import sqlalchemy
from sqlalchemy import (
Column,
ColumnElement,
Extract,
FunctionElement,
FunctionFilter,
Label,
Over,
TypeCoerce,
WithinGroup,
)
from sqlalchemy.orm import InstrumentedAttribute, Mapped
from sqlalchemy.sql._typing import (
_ColumnExpressionArgument,
_ColumnExpressionOrLiteralArgument,
_ColumnExpressionOrStrLabelArgument,
)
from sqlalchemy.sql.elements import (
BinaryExpression,
Case,
Cast,
CollectionAggregate,
ColumnClause,
SQLCoreOperations,
TryCast,
UnaryExpression,
)
from sqlalchemy.sql.expression import Select as _Select
from sqlalchemy.sql.roles import TypedColumnsClauseRole
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import Literal, Self
_TSelect = TypeVar("_TSelect")
_T = TypeVar("_T")
_TypeEngineArgument = Union[Type[TypeEngine[_T]], TypeEngine[_T]]
# Redefine operatos that would only take a column expresion to also take the (virtual)
# types of Pydantic models, e.g. str instead of only Mapped[str].
class Select(_Select, Generic[_TSelect]):
def all_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
return sqlalchemy.all_(expr) # type: ignore[arg-type]
def and_(
initial_clause: Union[Literal[True], _ColumnExpressionArgument[bool], bool],
*clauses: Union[_ColumnExpressionArgument[bool], bool],
) -> ColumnElement[bool]:
return sqlalchemy.and_(initial_clause, *clauses) # type: ignore[arg-type]
def any_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
return sqlalchemy.any_(expr) # type: ignore[arg-type]
def asc(
column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
) -> UnaryExpression[_T]:
return sqlalchemy.asc(column) # type: ignore[arg-type]
def collate(
expression: Union[_ColumnExpressionArgument[str], str], collation: str
) -> BinaryExpression[str]:
return sqlalchemy.collate(expression, collation) # type: ignore[arg-type]
def between(
expr: Union[_ColumnExpressionOrLiteralArgument[_T], _T],
lower_bound: Any,
upper_bound: Any,
symmetric: bool = False,
) -> BinaryExpression[bool]:
return sqlalchemy.between(expr, lower_bound, upper_bound, symmetric=symmetric) # type: ignore[arg-type]
def not_(clause: Union[_ColumnExpressionArgument[_T], _T]) -> ColumnElement[_T]:
return sqlalchemy.not_(clause) # type: ignore[arg-type]
def case(
*whens: Union[
Tuple[Union[_ColumnExpressionArgument[bool], bool], Any], Mapping[Any, Any]
],
value: Optional[Any] = None,
else_: Optional[Any] = None,
) -> Case[Any]:
return sqlalchemy.case(*whens, value=value, else_=else_) # type: ignore[arg-type]
def cast(
expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
type_: "_TypeEngineArgument[_T]",
) -> Cast[_T]:
return sqlalchemy.cast(expression, type_) # type: ignore[arg-type]
def try_cast(
expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
type_: "_TypeEngineArgument[_T]",
) -> TryCast[_T]:
return sqlalchemy.try_cast(expression, type_) # type: ignore[arg-type]
def desc(
column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
) -> UnaryExpression[_T]:
return sqlalchemy.desc(column) # type: ignore[arg-type]
def distinct(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
return sqlalchemy.distinct(expr) # type: ignore[arg-type]
def bitwise_not(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
return sqlalchemy.bitwise_not(expr) # type: ignore[arg-type]
def extract(field: str, expr: Union[_ColumnExpressionArgument[Any], Any]) -> Extract:
return sqlalchemy.extract(field, expr) # type: ignore[arg-type]
def funcfilter(
func: FunctionElement[_T], *criterion: Union[_ColumnExpressionArgument[bool], bool]
) -> FunctionFilter[_T]:
return sqlalchemy.funcfilter(func, *criterion) # type: ignore[arg-type]
def label(
name: str,
element: Union[_ColumnExpressionArgument[_T], _T],
type_: Optional["_TypeEngineArgument[_T]"] = None,
) -> Label[_T]:
return sqlalchemy.label(name, element, type_=type_) # type: ignore[arg-type]
def nulls_first(
column: Union[_ColumnExpressionArgument[_T], _T]
) -> UnaryExpression[_T]:
return sqlalchemy.nulls_first(column) # type: ignore[arg-type]
def nulls_last(column: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
return sqlalchemy.nulls_last(column) # type: ignore[arg-type]
def or_( # type: ignore[empty-body]
initial_clause: Union[Literal[False], _ColumnExpressionArgument[bool], bool],
*clauses: Union[_ColumnExpressionArgument[bool], bool],
) -> ColumnElement[bool]:
return sqlalchemy.or_(initial_clause, *clauses) # type: ignore[arg-type]
def over(
element: FunctionElement[_T],
partition_by: Optional[
Union[
Iterable[Union[_ColumnExpressionArgument[Any], Any]],
_ColumnExpressionArgument[Any],
Any,
]
] = None,
order_by: Optional[
Union[
Iterable[Union[_ColumnExpressionArgument[Any], Any]],
_ColumnExpressionArgument[Any],
Any,
]
] = None,
range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
) -> Over[_T]:
return sqlalchemy.over(
element, partition_by=partition_by, order_by=order_by, range_=range_, rows=rows
) # type: ignore[arg-type]
def tuple_(
*clauses: Union[_ColumnExpressionArgument[Any], Any],
types: Optional[Sequence["_TypeEngineArgument[Any]"]] = None,
) -> Tuple[Any, ...]:
return sqlalchemy.tuple_(*clauses, types=types) # type: ignore[return-value]
def type_coerce(
expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
type_: "_TypeEngineArgument[_T]",
) -> TypeCoerce[_T]:
return sqlalchemy.type_coerce(expression, type_) # type: ignore[arg-type]
def within_group(
element: FunctionElement[_T], *order_by: Union[_ColumnExpressionArgument[Any], Any]
) -> WithinGroup[_T]:
return sqlalchemy.within_group(element, *order_by) # type: ignore[arg-type]
# Separate this class in SelectBase, Select, and SelectOfScalar so that they can share
# where and having without having type overlap incompatibility in session.exec().
class SelectBase(_Select[Tuple[_T]]):
inherit_cache = True
def where(self, *whereclause: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
"""Return a new `Select` construct with the given expression added to
its `WHERE` clause, joined to the existing clause via `AND`, if any.
"""
return super().where(*whereclause) # type: ignore[arg-type]
def having(self, *having: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
"""Return a new `Select` construct with the given expression added to
its `HAVING` clause, joined to the existing clause via `AND`, if any.
"""
return super().having(*having) # type: ignore[arg-type]
class Select(SelectBase[_T]):
inherit_cache = True
@ -31,12 +239,15 @@ class Select(_Select, Generic[_TSelect]):
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
class SelectOfScalar(SelectBase[_T]):
inherit_cache = True
if TYPE_CHECKING: # pragma: no cover
from ..main import SQLModel
_TCCA = Union[
TypedColumnsClauseRole[_T],
SQLCoreOperations[_T],
Type[_T],
]
# Generated TypeVars start
@ -56,7 +267,7 @@ _TScalar_0 = TypeVar(
None,
)
_TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
_T0 = TypeVar("_T0")
_TScalar_1 = TypeVar(
@ -74,7 +285,7 @@ _TScalar_1 = TypeVar(
None,
)
_TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
_T1 = TypeVar("_T1")
_TScalar_2 = TypeVar(
@ -92,7 +303,7 @@ _TScalar_2 = TypeVar(
None,
)
_TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
_T2 = TypeVar("_T2")
_TScalar_3 = TypeVar(
@ -110,19 +321,19 @@ _TScalar_3 = TypeVar(
None,
)
_TModel_3 = TypeVar("_TModel_3", bound="SQLModel")
_T3 = TypeVar("_T3")
# Generated TypeVars end
@overload
def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore
def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore
...
@overload
def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore
def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]:
...
@ -133,7 +344,6 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1]]:
...
@ -141,27 +351,24 @@ def select( # type: ignore
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1]]:
__ent1: _TCCA[_T1],
) -> Select[Tuple[_TScalar_0, _T1]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
__ent0: _TCCA[_T0],
entity_1: _TScalar_1,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1]]:
) -> Select[Tuple[_T0, _TScalar_1]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1]]:
__ent0: _TCCA[_T0],
__ent1: _TCCA[_T1],
) -> Select[Tuple[_T0, _T1]]:
...
@ -170,7 +377,6 @@ def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: _TScalar_2,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]:
...
@ -179,69 +385,62 @@ def select( # type: ignore
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]:
__ent2: _TCCA[_T2],
) -> Select[Tuple[_TScalar_0, _TScalar_1, _T2]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
__ent1: _TCCA[_T1],
entity_2: _TScalar_2,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]:
) -> Select[Tuple[_TScalar_0, _T1, _TScalar_2]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]:
__ent1: _TCCA[_T1],
__ent2: _TCCA[_T2],
) -> Select[Tuple[_TScalar_0, _T1, _T2]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
__ent0: _TCCA[_T0],
entity_1: _TScalar_1,
entity_2: _TScalar_2,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]:
) -> Select[Tuple[_T0, _TScalar_1, _TScalar_2]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
__ent0: _TCCA[_T0],
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]:
__ent2: _TCCA[_T2],
) -> Select[Tuple[_T0, _TScalar_1, _T2]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
__ent0: _TCCA[_T0],
__ent1: _TCCA[_T1],
entity_2: _TScalar_2,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]:
) -> Select[Tuple[_T0, _T1, _TScalar_2]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]:
__ent0: _TCCA[_T0],
__ent1: _TCCA[_T1],
__ent2: _TCCA[_T2],
) -> Select[Tuple[_T0, _T1, _T2]]:
...
@ -251,7 +450,6 @@ def select( # type: ignore
entity_1: _TScalar_1,
entity_2: _TScalar_2,
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
...
@ -261,9 +459,8 @@ def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: _TScalar_2,
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]:
__ent3: _TCCA[_T3],
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _T3]]:
...
@ -271,10 +468,9 @@ def select( # type: ignore
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
__ent2: _TCCA[_T2],
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]:
) -> Select[Tuple[_TScalar_0, _TScalar_1, _T2, _TScalar_3]]:
...
@ -282,156 +478,142 @@ def select( # type: ignore
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]:
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
) -> Select[Tuple[_TScalar_0, _TScalar_1, _T2, _T3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
__ent1: _TCCA[_T1],
entity_2: _TScalar_2,
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]:
) -> Select[Tuple[_TScalar_0, _T1, _TScalar_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
__ent1: _TCCA[_T1],
entity_2: _TScalar_2,
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]:
__ent3: _TCCA[_T3],
) -> Select[Tuple[_TScalar_0, _T1, _TScalar_2, _T3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
__ent1: _TCCA[_T1],
__ent2: _TCCA[_T2],
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]:
) -> Select[Tuple[_TScalar_0, _T1, _T2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: _TScalar_0,
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]:
__ent1: _TCCA[_T1],
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
) -> Select[Tuple[_TScalar_0, _T1, _T2, _T3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
__ent0: _TCCA[_T0],
entity_1: _TScalar_1,
entity_2: _TScalar_2,
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
) -> Select[Tuple[_T0, _TScalar_1, _TScalar_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
__ent0: _TCCA[_T0],
entity_1: _TScalar_1,
entity_2: _TScalar_2,
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]:
__ent3: _TCCA[_T3],
) -> Select[Tuple[_T0, _TScalar_1, _TScalar_2, _T3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
__ent0: _TCCA[_T0],
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
__ent2: _TCCA[_T2],
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]:
) -> Select[Tuple[_T0, _TScalar_1, _T2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
__ent0: _TCCA[_T0],
entity_1: _TScalar_1,
entity_2: Type[_TModel_2],
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]:
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
) -> Select[Tuple[_T0, _TScalar_1, _T2, _T3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
__ent0: _TCCA[_T0],
__ent1: _TCCA[_T1],
entity_2: _TScalar_2,
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]:
) -> Select[Tuple[_T0, _T1, _TScalar_2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
__ent0: _TCCA[_T0],
__ent1: _TCCA[_T1],
entity_2: _TScalar_2,
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]:
__ent3: _TCCA[_T3],
) -> Select[Tuple[_T0, _T1, _TScalar_2, _T3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
__ent0: _TCCA[_T0],
__ent1: _TCCA[_T1],
__ent2: _TCCA[_T2],
entity_3: _TScalar_3,
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]:
) -> Select[Tuple[_T0, _T1, _T2, _TScalar_3]]:
...
@overload
def select( # type: ignore
entity_0: Type[_TModel_0],
entity_1: Type[_TModel_1],
entity_2: Type[_TModel_2],
entity_3: Type[_TModel_3],
**kw: Any,
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]:
__ent0: _TCCA[_T0],
__ent1: _TCCA[_T1],
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
) -> Select[Tuple[_T0, _T1, _T2, _T3]]:
...
# Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore
if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw) # type: ignore
return SelectOfScalar(*entities)
return Select(*entities)
# TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause: # type: ignore
def col(column_expression: _T) -> Mapped[_T]:
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression
return column_expression # type: ignore

View File

@ -1,9 +1,9 @@
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
@ -13,28 +13,243 @@ from typing import (
)
from uuid import UUID
from sqlalchemy import Column
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnClause
import sqlalchemy
from sqlalchemy import (
Column,
ColumnElement,
Extract,
FunctionElement,
FunctionFilter,
Label,
Over,
TypeCoerce,
WithinGroup,
)
from sqlalchemy.orm import InstrumentedAttribute, Mapped
from sqlalchemy.sql._typing import (
_ColumnExpressionArgument,
_ColumnExpressionOrLiteralArgument,
_ColumnExpressionOrStrLabelArgument,
)
from sqlalchemy.sql.elements import (
BinaryExpression,
Case,
Cast,
CollectionAggregate,
ColumnClause,
SQLCoreOperations,
TryCast,
UnaryExpression,
)
from sqlalchemy.sql.expression import Select as _Select
from sqlalchemy.sql.roles import TypedColumnsClauseRole
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import Literal, Self
_TSelect = TypeVar("_TSelect")
_T = TypeVar("_T")
class Select(_Select, Generic[_TSelect]):
_TypeEngineArgument = Union[Type[TypeEngine[_T]], TypeEngine[_T]]
# Redefine operatos that would only take a column expresion to also take the (virtual)
# types of Pydantic models, e.g. str instead of only Mapped[str].
def all_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
return sqlalchemy.all_(expr) # type: ignore[arg-type]
def and_(
initial_clause: Union[Literal[True], _ColumnExpressionArgument[bool], bool],
*clauses: Union[_ColumnExpressionArgument[bool], bool],
) -> ColumnElement[bool]:
return sqlalchemy.and_(initial_clause, *clauses) # type: ignore[arg-type]
def any_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
return sqlalchemy.any_(expr) # type: ignore[arg-type]
def asc(
column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
) -> UnaryExpression[_T]:
return sqlalchemy.asc(column) # type: ignore[arg-type]
def collate(
expression: Union[_ColumnExpressionArgument[str], str], collation: str
) -> BinaryExpression[str]:
return sqlalchemy.collate(expression, collation) # type: ignore[arg-type]
def between(
expr: Union[_ColumnExpressionOrLiteralArgument[_T], _T],
lower_bound: Any,
upper_bound: Any,
symmetric: bool = False,
) -> BinaryExpression[bool]:
return sqlalchemy.between(expr, lower_bound, upper_bound, symmetric=symmetric) # type: ignore[arg-type]
def not_(clause: Union[_ColumnExpressionArgument[_T], _T]) -> ColumnElement[_T]:
return sqlalchemy.not_(clause) # type: ignore[arg-type]
def case(
*whens: Union[
Tuple[Union[_ColumnExpressionArgument[bool], bool], Any], Mapping[Any, Any]
],
value: Optional[Any] = None,
else_: Optional[Any] = None,
) -> Case[Any]:
return sqlalchemy.case(*whens, value=value, else_=else_) # type: ignore[arg-type]
def cast(
expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
type_: "_TypeEngineArgument[_T]",
) -> Cast[_T]:
return sqlalchemy.cast(expression, type_) # type: ignore[arg-type]
def try_cast(
expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
type_: "_TypeEngineArgument[_T]",
) -> TryCast[_T]:
return sqlalchemy.try_cast(expression, type_) # type: ignore[arg-type]
def desc(
column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
) -> UnaryExpression[_T]:
return sqlalchemy.desc(column) # type: ignore[arg-type]
def distinct(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
return sqlalchemy.distinct(expr) # type: ignore[arg-type]
def bitwise_not(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
return sqlalchemy.bitwise_not(expr) # type: ignore[arg-type]
def extract(field: str, expr: Union[_ColumnExpressionArgument[Any], Any]) -> Extract:
return sqlalchemy.extract(field, expr) # type: ignore[arg-type]
def funcfilter(
func: FunctionElement[_T], *criterion: Union[_ColumnExpressionArgument[bool], bool]
) -> FunctionFilter[_T]:
return sqlalchemy.funcfilter(func, *criterion) # type: ignore[arg-type]
def label(
name: str,
element: Union[_ColumnExpressionArgument[_T], _T],
type_: Optional["_TypeEngineArgument[_T]"] = None,
) -> Label[_T]:
return sqlalchemy.label(name, element, type_=type_) # type: ignore[arg-type]
def nulls_first(
column: Union[_ColumnExpressionArgument[_T], _T]
) -> UnaryExpression[_T]:
return sqlalchemy.nulls_first(column) # type: ignore[arg-type]
def nulls_last(column: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
return sqlalchemy.nulls_last(column) # type: ignore[arg-type]
def or_( # type: ignore[empty-body]
initial_clause: Union[Literal[False], _ColumnExpressionArgument[bool], bool],
*clauses: Union[_ColumnExpressionArgument[bool], bool],
) -> ColumnElement[bool]:
return sqlalchemy.or_(initial_clause, *clauses) # type: ignore[arg-type]
def over(
element: FunctionElement[_T],
partition_by: Optional[
Union[
Iterable[Union[_ColumnExpressionArgument[Any], Any]],
_ColumnExpressionArgument[Any],
Any,
]
] = None,
order_by: Optional[
Union[
Iterable[Union[_ColumnExpressionArgument[Any], Any]],
_ColumnExpressionArgument[Any],
Any,
]
] = None,
range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
) -> Over[_T]:
return sqlalchemy.over(
element, partition_by=partition_by, order_by=order_by, range_=range_, rows=rows
) # type: ignore[arg-type]
def tuple_(
*clauses: Union[_ColumnExpressionArgument[Any], Any],
types: Optional[Sequence["_TypeEngineArgument[Any]"]] = None,
) -> Tuple[Any, ...]:
return sqlalchemy.tuple_(*clauses, types=types) # type: ignore[return-value]
def type_coerce(
expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
type_: "_TypeEngineArgument[_T]",
) -> TypeCoerce[_T]:
return sqlalchemy.type_coerce(expression, type_) # type: ignore[arg-type]
def within_group(
element: FunctionElement[_T], *order_by: Union[_ColumnExpressionArgument[Any], Any]
) -> WithinGroup[_T]:
return sqlalchemy.within_group(element, *order_by) # type: ignore[arg-type]
# Separate this class in SelectBase, Select, and SelectOfScalar so that they can share
# where and having without having type overlap incompatibility in session.exec().
class SelectBase(_Select[Tuple[_T]]):
inherit_cache = True
def where(self, *whereclause: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
"""Return a new `Select` construct with the given expression added to
its `WHERE` clause, joined to the existing clause via `AND`, if any.
"""
return super().where(*whereclause) # type: ignore[arg-type]
def having(self, *having: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
"""Return a new `Select` construct with the given expression added to
its `HAVING` clause, joined to the existing clause via `AND`, if any.
"""
return super().having(*having) # type: ignore[arg-type]
class Select(SelectBase[_T]):
inherit_cache = True
# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select, Generic[_TSelect]):
class SelectOfScalar(SelectBase[_T]):
inherit_cache = True
if TYPE_CHECKING: # pragma: no cover
from ..main import SQLModel
_TCCA = Union[
TypedColumnsClauseRole[_T],
SQLCoreOperations[_T],
Type[_T],
]
# Generated TypeVars start
{% for i in range(number_of_types) %}
_TScalar_{{ i }} = TypeVar(
"_TScalar_{{ i }}",
@ -51,19 +266,19 @@ _TScalar_{{ i }} = TypeVar(
None,
)
_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel")
_T{{ i }} = TypeVar("_T{{ i }}")
{% endfor %}
# Generated TypeVars end
@overload
def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore
def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore
...
@overload
def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore
def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]:
...
@ -73,7 +288,7 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:
@overload
def select( # type: ignore
{% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any,
{% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}
) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]:
...
@ -81,14 +296,14 @@ def select( # type: ignore
# Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore
if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw) # type: ignore
return SelectOfScalar(*entities)
return Select(*entities)
# TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause: # type: ignore
def col(column_expression: _T) -> Mapped[_T]:
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression
return column_expression # type: ignore

View File

@ -15,7 +15,7 @@ class AutoString(types.TypeDecorator): # type: ignore
def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
impl = cast(types.String, self.impl)
if impl.length is None and dialect.name == "mysql":
return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore
return dialect.type_descriptor(types.String(self.mysql_default_length))
return super().load_dialect_impl(dialect)
@ -32,11 +32,11 @@ class GUID(types.TypeDecorator): # type: ignore
impl = CHAR
cache_ok = True
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID()) # type: ignore
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(32)) # type: ignore
return dialect.type_descriptor(CHAR(32))
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
if value is None:

View File

@ -59,7 +59,7 @@ def test_tutorial(clear_sqlmodel):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -315,7 +315,9 @@ def test_tutorial(clear_sqlmodel):
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {
"anyOf": [{"type": "string"}, {"type": "integer"}]
},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -64,7 +64,7 @@ def test_tutorial(clear_sqlmodel):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -239,7 +239,9 @@ def test_tutorial(clear_sqlmodel):
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {
"anyOf": [{"type": "string"}, {"type": "integer"}]
},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -5,7 +5,7 @@ from sqlmodel import create_engine
from sqlmodel.pool import StaticPool
openapi_schema = {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -103,7 +103,7 @@ openapi_schema = {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -5,7 +5,7 @@ from sqlmodel import create_engine
from sqlmodel.pool import StaticPool
openapi_schema = {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -103,7 +103,7 @@ openapi_schema = {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -3,7 +3,7 @@ from sqlmodel import create_engine
from sqlmodel.pool import StaticPool
openapi_schema = {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -135,7 +135,7 @@ openapi_schema = {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -107,7 +107,7 @@ def test_tutorial(clear_sqlmodel):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -622,7 +622,9 @@ def test_tutorial(clear_sqlmodel):
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {
"anyOf": [{"type": "string"}, {"type": "integer"}]
},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -3,7 +3,7 @@ from sqlmodel import create_engine
from sqlmodel.pool import StaticPool
openapi_schema = {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -91,7 +91,7 @@ openapi_schema = {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -59,7 +59,7 @@ def test_tutorial(clear_sqlmodel):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -315,7 +315,9 @@ def test_tutorial(clear_sqlmodel):
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {
"anyOf": [{"type": "string"}, {"type": "integer"}]
},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -3,7 +3,7 @@ from sqlmodel import create_engine
from sqlmodel.pool import StaticPool
openapi_schema = {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -79,7 +79,7 @@ openapi_schema = {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -94,7 +94,7 @@ def test_tutorial(clear_sqlmodel):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -579,7 +579,9 @@ def test_tutorial(clear_sqlmodel):
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {
"anyOf": [{"type": "string"}, {"type": "integer"}]
},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},

View File

@ -66,7 +66,7 @@ def test_tutorial(clear_sqlmodel):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == {
"openapi": "3.0.2",
"openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/heroes/": {
@ -294,7 +294,9 @@ def test_tutorial(clear_sqlmodel):
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
"items": {
"anyOf": [{"type": "string"}, {"type": "integer"}]
},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},