Compare commits

..

21 Commits
0.0.2 ... 0.0.5

Author SHA1 Message Date
Sebastián Ramírez
02697459b8 🔖 Release version 0.0.5 2021-12-13 12:41:51 +01:00
github-actions
7eadc90558 📝 Update release notes 2021-12-13 11:38:40 +00:00
Sebastián Ramírez
95c02962ba ✏ Update decimal tutorial source for consistency (#188) 2021-12-13 11:37:59 +00:00
github-actions
75540f9728 📝 Update release notes 2021-12-13 11:30:57 +00:00
robcxyz
580f372059 Add support for Decimal fields from Pydantic and SQLAlchemy (#103)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
2021-12-13 12:30:20 +01:00
github-actions
1c276ef88f 📝 Update release notes 2021-12-13 10:47:44 +00:00
Sebastián Ramírez
14a9788eb1 🔧 Split MkDocs insiders build in CI to support building from PRs (#186) 2021-12-13 11:47:07 +01:00
github-actions
dbcaa50c69 📝 Update release notes 2021-12-13 10:41:14 +00:00
Sebastián Ramírez
362eb81701 🎨 Format expression.py and expression template, currently needed by CI (#187) 2021-12-13 10:40:40 +00:00
github-actions
a36c6d5778 📝 Update release notes 2021-12-03 10:24:01 +00:00
Lehoczky Zoltán
82935cae9f 🐛Fix docs light/dark theme switcher (#1)
* 🐛Fix tooltip text for theme switcher

* 🔧 Update lightbulb icon

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
2021-12-03 11:23:20 +01:00
github-actions
455794da2c 📝 Update release notes 2021-11-30 16:28:31 +00:00
Sebastián Ramírez
55259b3c8b 🔧 Add MkDocs Material social cards (#90) 2021-11-30 17:27:50 +01:00
github-actions
328c8c725d 📝 Update release notes 2021-11-30 16:13:10 +00:00
Sebastián Ramírez
e30c7ef4e9 Update type annotations and upgrade mypy (#173) 2021-11-30 17:12:28 +01:00
Sebastián Ramírez
02da85c9ec 🔖 Release version 0.0.4 2021-08-25 15:46:57 +02:00
github-actions
878e230782 📝 Update release notes 2021-08-25 13:44:35 +00:00
Sebastián Ramírez
1da849ac48 🎨 Fix type detection of select results in PyCharm (#15) 2021-08-25 13:43:53 +00:00
Sebastián Ramírez
af03df88ac 🔖 Release version 0.0.3 2021-08-24 20:44:18 +02:00
github-actions
d80a2fb7ed 📝 Update release notes 2021-08-24 18:43:29 +00:00
Sebastián Ramírez
230911ab42 ⬆️ Update and relax specification range for sqlalchemy-stubs (#4) 2021-08-24 20:42:52 +02:00
24 changed files with 410 additions and 102 deletions

View File

@@ -51,8 +51,16 @@ jobs:
- name: Install Material for MkDocs Insiders
if: github.event.pull_request.head.repo.fork == false && steps.cache.outputs.cache-hit != 'true'
run: python -m poetry run pip install git+https://${{ secrets.ACTIONS_TOKEN }}@github.com/squidfunk/mkdocs-material-insiders.git
- uses: actions/cache@v2
with:
key: mkdocs-cards-${{ github.ref }}
path: .cache
- name: Build Docs
if: github.event.pull_request.head.repo.fork == true
run: python -m poetry run mkdocs build
- name: Build Docs with Insiders
if: github.event.pull_request.head.repo.fork == false
run: python -m poetry run mkdocs build --config-file mkdocs.insiders.yml
- name: Zip docs
run: python -m poetry run bash ./scripts/zip-docs.sh
- uses: actions/upload-artifact@v2

1
.gitignore vendored
View File

@@ -11,3 +11,4 @@ htmlcov
coverage.xml
site
*.db
.cache

148
docs/advanced/decimal.md Normal file
View File

@@ -0,0 +1,148 @@
# Decimal Numbers
In some cases you might need to be able to store decimal numbers with guarantees about the precision.
This is particularly important if you are storing things like **currencies**, **prices**, **accounts**, and others, as you would want to know that you wouldn't have rounding errors.
As an example, if you open Python and sum `1.1` + `2.2` you would expect to see `3.3`, but you will actually get `3.3000000000000003`:
```Python
>>> 1.1 + 2.2
3.3000000000000003
```
This is because of the way numbers are stored in "ones and zeros" (binary). But Python has a module and some types to have strict decimal values. You can read more about it in the official <a href="https://docs.python.org/3/library/decimal.html" class="external-link" target="_blank">Python docs for Decimal</a>.
Because databases store data in the same ways as computers (in binary), they would have the same types of issues. And because of that, they also have a special **decimal** type.
In most cases this would probably not be a problem, for example measuring views in a video, or the life bar in a videogame. But as you can imagine, this is particularly important when dealing with **money** and **finances**.
## Decimal Types
Pydantic has special support for `Decimal` types using the <a href="https://pydantic-docs.helpmanual.io/usage/types/#arguments-to-condecimal" class="external-link" target="_blank">`condecimal()` special function</a>.
!!! tip
Pydantic 1.9, that will be released soon, has improved support for `Decimal` types, without needing to use the `condecimal()` function.
But meanwhile, you can already use this feature with `condecimal()` in **SQLModel** it as it's explained here.
When you use `condecimal()` you can specify the number of digits and decimal places to support. They will be validated by Pydantic (for example when using FastAPI) and the same information will also be used for the database columns.
!!! info
For the database, **SQLModel** will use <a href="https://docs.sqlalchemy.org/en/14/core/type_basics.html#sqlalchemy.types.DECIMAL" class="external-link" target="_blank">SQLAlchemy's `DECIMAL` type</a>.
## Decimals in SQLModel
Let's say that each hero in the database will have an amount of money. We could make that field a `Decimal` type using the `condecimal()` function:
```{.python .annotate hl_lines="12" }
{!./docs_src/advanced/decimal/tutorial001.py[ln:1-12]!}
# More code here later 👇
```
<details>
<summary>👀 Full file preview</summary>
```Python
{!./docs_src/advanced/decimal/tutorial001.py!}
```
</details>
Here we are saying that `money` can have at most `5` digits with `max_digits`, **this includes the integers** (to the left of the decimal dot) **and the decimals** (to the right of the decimal dot).
We are also saying that the number of decimal places (to the right of the decimal dot) is `3`, so we can have **3 decimal digits** for these numbers in the `money` field. This means that we will have **2 digits for the integer part** and **3 digits for the decimal part**.
✅ So, for example, these are all valid numbers for the `money` field:
* `12.345`
* `12.3`
* `12`
* `1.2`
* `0.123`
* `0`
🚫 But these are all invalid numbers for that `money` field:
* `1.2345`
* This number has more than 3 decimal places.
* `123.234`
* This number has more than 5 digits in total (integer and decimal part).
* `123`
* Even though this number doesn't have any decimals, we still have 3 places saved for them, which means that we can **only use 2 places** for the **integer part**, and this number has 3 integer digits. So, the allowed number of integer digits is `max_digits` - `decimal_places` = 2.
!!! tip
Make sure you adjust the number of digits and decimal places for your own needs, in your own application. 🤓
## Create models with Decimals
When creating new models you can actually pass normal (`float`) numbers, Pydantic will automatically convert them to `Decimal` types, and **SQLModel** will store them as `Decimal` types in the database (using SQLAlchemy).
```Python hl_lines="4-6"
# Code above omitted 👆
{!./docs_src/advanced/decimal/tutorial001.py[ln:25-35]!}
# Code below omitted 👇
```
<details>
<summary>👀 Full file preview</summary>
```Python
{!./docs_src/advanced/decimal/tutorial001.py!}
```
</details>
## Select Decimal data
Then, when working with Decimal types, you can confirm that they indeed avoid those rounding errors from floats:
```Python hl_lines="15-16"
# Code above omitted 👆
{!./docs_src/advanced/decimal/tutorial001.py[ln:38-51]!}
# Code below omitted 👇
```
<details>
<summary>👀 Full file preview</summary>
```Python
{!./docs_src/advanced/decimal/tutorial001.py!}
```
</details>
## Review the results
Now if you run this, instead of printing the unexpected number `3.3000000000000003`, it prints `3.300`:
<div class="termy">
```console
$ python app.py
// Some boilerplate and previous output omitted 😉
// The type of money is Decimal('1.100')
Hero 1: id=1 secret_name='Dive Wilson' age=None name='Deadpond' money=Decimal('1.100')
// More output omitted here 🤓
// The type of money is Decimal('1.100')
Hero 2: id=3 secret_name='Tommy Sharp' age=48 name='Rusty-Man' money=Decimal('2.200')
// No rounding errors, just 3.3! 🎉
Total money: 3.300
```
</div>
!!! warning
Although Decimal types are supported and used in the Python side, not all databases support it. In particular, SQLite doesn't support decimals, so it will convert them to the same floating `NUMERIC` type it supports.
But decimals are supported by most of the other SQL databases. 🎉

View File

@@ -1,12 +1,10 @@
# Advanced User Guide
The **Advanced User Guide** will be coming soon to a <del>theater</del> **documentation** near you... 😅
The **Advanced User Guide** is gradually growing, you can already read about some advanced topics.
I just have to `add` it, `commit` it, and `refresh` it. 😉
At some point it will include:
It will include:
* How to use the `async` and `await` with the async session.
* How to use `async` and `await` with the async session.
* How to run migrations.
* How to combine **SQLModel** models with SQLAlchemy.
* ...and more.
* ...and more. 🤓

View File

@@ -3,6 +3,32 @@
## Latest Changes
## 0.0.5
### Features
* ✨ Add support for Decimal fields from Pydantic and SQLAlchemy. Original PR [#103](https://github.com/tiangolo/sqlmodel/pull/103) by [@robcxyz](https://github.com/robcxyz). New docs: [Advanced User Guide - Decimal Numbers](https://sqlmodel.tiangolo.com/advanced/decimal/).
### Docs
* ✏ Update decimal tutorial source for consistency. PR [#188](https://github.com/tiangolo/sqlmodel/pull/188) by [@tiangolo](https://github.com/tiangolo).
### Internal
* 🔧 Split MkDocs insiders build in CI to support building from PRs. PR [#186](https://github.com/tiangolo/sqlmodel/pull/186) by [@tiangolo](https://github.com/tiangolo).
* 🎨 Format `expression.py` and expression template, currently needed by CI. PR [#187](https://github.com/tiangolo/sqlmodel/pull/187) by [@tiangolo](https://github.com/tiangolo).
* 🐛Fix docs light/dark theme switcher. PR [#1](https://github.com/tiangolo/sqlmodel/pull/1) by [@Lehoczky](https://github.com/Lehoczky).
* 🔧 Add MkDocs Material social cards. PR [#90](https://github.com/tiangolo/sqlmodel/pull/90) by [@tiangolo](https://github.com/tiangolo).
* ✨ Update type annotations and upgrade mypy. PR [#173](https://github.com/tiangolo/sqlmodel/pull/173) by [@tiangolo](https://github.com/tiangolo).
## 0.0.4
* 🎨 Fix type detection of select results in PyCharm. PR [#15](https://github.com/tiangolo/sqlmodel/pull/15) by [@tiangolo](https://github.com/tiangolo).
## 0.0.3
* ⬆️ Update and relax specification range for `sqlalchemy-stubs`. PR [#4](https://github.com/tiangolo/sqlmodel/pull/4) by [@tiangolo](https://github.com/tiangolo).
## 0.0.2
* This includes several small bug fixes detected during the first CI runs.

View File

View File

View File

@@ -0,0 +1,61 @@
from typing import Optional
from pydantic import condecimal
from sqlmodel import Field, Session, SQLModel, create_engine, select
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None
money: condecimal(max_digits=5, decimal_places=3) = Field(default=0)
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
engine = create_engine(sqlite_url, echo=True)
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
def create_heroes():
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson", money=1.1)
hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador", money=0.001)
hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48, money=2.2)
with Session(engine) as session:
session.add(hero_1)
session.add(hero_2)
session.add(hero_3)
session.commit()
def select_heroes():
with Session(engine) as session:
statement = select(Hero).where(Hero.name == "Deadpond")
results = session.exec(statement)
hero_1 = results.one()
print("Hero 1:", hero_1)
statement = select(Hero).where(Hero.name == "Rusty-Man")
results = session.exec(statement)
hero_2 = results.one()
print("Hero 2:", hero_2)
total_money = hero_1.money + hero_2.money
print(f"Total money: {total_money}")
def main():
create_db_and_tables()
create_heroes()
select_heroes()
if __name__ == "__main__":
main()

4
mkdocs.insiders.yml Normal file
View File

@@ -0,0 +1,4 @@
INHERIT: mkdocs.yml
plugins:
- search
- social

View File

@@ -8,14 +8,14 @@ theme:
primary: deep purple
accent: amber
toggle:
icon: material/lightbulb-outline
name: Switch to light mode
icon: material/lightbulb
name: Switch to dark mode
- scheme: slate
primary: deep purple
accent: amber
toggle:
icon: material/lightbulb
name: Switch to dark mode
icon: material/lightbulb-outline
name: Switch to light mode
features:
- search.suggest
- search.highlight
@@ -30,8 +30,6 @@ edit_uri: ''
google_analytics:
- UA-205713594-2
- auto
plugins:
- search
nav:
- SQLModel: index.md
- features.md
@@ -86,6 +84,7 @@ nav:
- tutorial/fastapi/tests.md
- Advanced User Guide:
- advanced/index.md
- advanced/decimal.md
- alternatives.md
- help.md
- contributing.md

View File

@@ -33,11 +33,11 @@ classifiers = [
python = "^3.6.1"
SQLAlchemy = ">=1.4.17,<1.5.0"
pydantic = "^1.8.2"
sqlalchemy2-stubs = "^0.0.2-alpha.5"
sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
[tool.poetry.dev-dependencies]
pytest = "^6.2.4"
mypy = "^0.812"
mypy = "^0.910"
flake8 = "^3.9.2"
black = {version = "^21.5-beta.1", python = "^3.7"}
mkdocs = "^1.2.1"
@@ -98,3 +98,7 @@ warn_return_any = true
implicit_reexport = false
strict_equality = true
# --strict end
[[tool.mypy.overrides]]
module = "sqlmodel.sql.expression"
warn_unused_ignores = false

View File

@@ -1,4 +1,4 @@
__version__ = "0.0.2"
__version__ = "0.0.5"
# Re-export from SQLAlchemy
from sqlalchemy.engine import create_mock_engine as create_mock_engine

View File

@@ -136,4 +136,4 @@ def create_engine(
if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs)
return _create_engine(url, **current_kwargs)
return _create_engine(url, **current_kwargs) # type: ignore

View File

@@ -23,7 +23,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
return super().__iter__()
def __next__(self) -> _T:
return super().__next__()
return super().__next__() # type: ignore
def first(self) -> Optional[_T]:
return super().first()
@@ -32,7 +32,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
return super().one_or_none()
def one(self) -> _T:
return super().one()
return super().one() # type: ignore
class Result(_Result, Generic[_T]):
@@ -70,10 +70,10 @@ class Result(_Result, Generic[_T]):
return super().scalar_one() # type: ignore
def scalar_one_or_none(self) -> Optional[_T]:
return super().scalar_one_or_none() # type: ignore
return super().scalar_one_or_none()
def one(self) -> _T: # type: ignore
return super().one() # type: ignore
def scalar(self) -> Optional[_T]:
return super().scalar() # type: ignore
return super().scalar()

View File

@@ -21,7 +21,7 @@ class AsyncSession(_AsyncSession):
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
**kw,
**kw: Any,
):
# All the same code of the original AsyncSession
kw["future"] = True
@@ -52,7 +52,7 @@ class AsyncSession(_AsyncSession):
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
return await greenlet_spawn( # type: ignore
return await greenlet_spawn(
self.sync_session.exec,
statement,
params=params,

View File

@@ -101,7 +101,7 @@ class RelationshipInfo(Representation):
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> None:
@@ -127,32 +127,32 @@ def Field(
default: Any = Undefined,
*,
default_factory: Optional[NoArgAnyCallable] = None,
alias: str = None,
title: str = None,
description: str = None,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
exclude: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
include: Union[
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
] = None,
const: bool = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
multiple_of: float = None,
min_items: int = None,
max_items: int = None,
min_length: int = None,
max_length: int = None,
const: Optional[bool] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: str = None,
regex: Optional[str] = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
@@ -195,7 +195,7 @@ def Relationship(
*,
back_populates: Optional[str] = None,
link_model: Optional[Any] = None,
sa_relationship: Optional[RelationshipProperty] = None,
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
@@ -217,19 +217,25 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
if getattr(cls.__config__, "table", False): # type: ignore
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__setattr__(cls, name, value)
else:
super().__setattr__(name, value)
def __delattr__(cls, name: str) -> None:
if getattr(cls.__config__, "table", False): # type: ignore
if getattr(cls.__config__, "table", False):
DeclarativeMeta.__delattr__(cls, name)
else:
super().__delattr__(name)
# From Pydantic
def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
def __new__(
cls,
name: str,
bases: Tuple[Type[Any], ...],
class_dict: Dict[str, Any],
**kwargs: Any,
) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {}
original_annotations = resolve_annotations(
@@ -342,7 +348,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
)
relationship_to = temp_field.type_
if isinstance(temp_field.type_, ForwardRef):
relationship_to = temp_field.type_.__forward_arg__ # type: ignore
relationship_to = temp_field.type_.__forward_arg__
rel_kwargs: Dict[str, Any] = {}
if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates
@@ -360,7 +366,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value: RelationshipProperty = relationship(
rel_value: RelationshipProperty = relationship( # type: ignore
relationship_to, *rel_args, **rel_kwargs
)
dict_used[rel_name] = rel_value
@@ -393,7 +399,10 @@ def get_sqlachemy_type(field: ModelField) -> Any:
if issubclass(field.type_, bytes):
return LargeBinary
if issubclass(field.type_, Decimal):
return Numeric
return Numeric(
precision=getattr(field.type_, "max_digits", None),
scale=getattr(field.type_, "decimal_places", None),
)
if issubclass(field.type_, ipaddress.IPv4Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv4Network):
@@ -408,7 +417,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
return GUID
def get_column_from_field(field: ModelField) -> Column:
def get_column_from_field(field: ModelField) -> Column: # type: ignore
sa_column = getattr(field.field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
return sa_column
@@ -440,10 +449,10 @@ def get_column_from_field(field: ModelField) -> Column:
kwargs["default"] = sa_default
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined:
args.extend(list(cast(Sequence, sa_column_args)))
args.extend(list(cast(Sequence[Any], sa_column_args)))
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
if sa_column_kwargs is not Undefined:
kwargs.update(cast(dict, sa_column_kwargs))
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
return Column(sa_type, *args, **kwargs)
@@ -452,24 +461,27 @@ class_registry = weakref.WeakValueDictionary() # type: ignore
default_registry = registry()
def _value_items_is_true(v) -> bool:
def _value_items_is_true(v: Any) -> bool:
# Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
# the current latest, Pydantic 1.8.2
return v is True or v is ...
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
class Config:
orm_mode = True
def __new__(cls, *args, **kwargs) -> Any:
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
new_object = super().__new__(cls)
# SQLAlchemy doesn't call __init__ on the base class
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
@@ -520,7 +532,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
super().__setattr__(name, value)
@classmethod
def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
def from_orm(
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
) -> _TSQLModel:
# Duplicated from Pydantic
if not cls.__config__.orm_mode:
raise ConfigError(
@@ -533,7 +547,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
# End SQLModel support dict
if not getattr(cls.__config__, "table", False):
# If not table, normal Pydantic code
m = cls.__new__(cls)
m: _TSQLModel = cls.__new__(cls)
else:
# If table, create the new instance normally to make SQLAlchemy create
# the _sa_instance_state attribute
@@ -554,7 +568,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
@classmethod
def parse_obj(
cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
) -> "SQLModel":
obj = cls._enforce_dict_if_root(obj)
# SQLModel, support update dict

View File

@@ -10,14 +10,14 @@ from typing_extensions import Literal
from ..engine.result import Result, ScalarResult
from ..sql.base import Executable
_T = TypeVar("_T")
_TSelectParam = TypeVar("_TSelectParam")
class Session(_Session):
@overload
def exec(
self,
statement: Select[_T],
statement: Select[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
@@ -25,13 +25,13 @@ class Session(_Session):
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[Result[_T]]:
) -> Result[_TSelectParam]:
...
@overload
def exec(
self,
statement: SelectOfScalar[_T],
statement: SelectOfScalar[_TSelectParam],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
@@ -39,12 +39,16 @@ class Session(_Session):
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[ScalarResult[_T]]:
) -> ScalarResult[_TSelectParam]:
...
def exec(
self,
statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
statement: Union[
Select[_TSelectParam],
SelectOfScalar[_TSelectParam],
Executable[_TSelectParam],
],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
@@ -52,11 +56,11 @@ class Session(_Session):
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[Result[_T], ScalarResult[_T]]:
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
results = super().execute(
statement,
params=params,
execution_options=execution_options, # type: ignore
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
@@ -70,7 +74,7 @@ class Session(_Session):
self,
statement: _Executable,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
@@ -97,7 +101,7 @@ class Session(_Session):
return super().execute( # type: ignore
statement,
params=params,
execution_options=execution_options, # type: ignore
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
@@ -118,13 +122,13 @@ class Session(_Session):
def get(
self,
entity: Type[_T],
entity: Type[_TSelectParam],
ident: Any,
options: Optional[Sequence[Any]] = None,
populate_existing: bool = False,
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
identity_token: Optional[Any] = None,
) -> Optional[_T]:
) -> Optional[_TSelectParam]:
return super().get(
entity,
ident,

View File

@@ -6,6 +6,4 @@ _T = TypeVar("_T")
class Executable(_Executable, Generic[_T]):
def __init__(self, *args, **kwargs):
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
super(_Executable, self).__init__(*args, **kwargs)
pass

View File

@@ -38,17 +38,16 @@ if sys.version_info.minor >= 7:
class SelectOfScalar(_Select, Generic[_TSelect]):
pass
else:
from typing import GenericMeta # type: ignore
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
pass
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
pass
# Cast them for editors to work correctly, from several tricks tried, this works
@@ -65,9 +64,9 @@ if TYPE_CHECKING: # pragma: no cover
_TScalar_0 = TypeVar(
"_TScalar_0",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -83,9 +82,9 @@ _TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
_TScalar_1 = TypeVar(
"_TScalar_1",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -101,9 +100,9 @@ _TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
_TScalar_2 = TypeVar(
"_TScalar_2",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -119,9 +118,9 @@ _TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
_TScalar_3 = TypeVar(
"_TScalar_3",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -446,14 +445,14 @@ def select( # type: ignore
# Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw) # type: ignore
# TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause:
def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression

View File

@@ -36,7 +36,6 @@ if sys.version_info.minor >= 7:
class SelectOfScalar(_Select, Generic[_TSelect]):
pass
else:
from typing import GenericMeta # type: ignore
@@ -63,9 +62,9 @@ if TYPE_CHECKING: # pragma: no cover
{% for i in range(number_of_types) %}
_TScalar_{{ i }} = TypeVar(
"_TScalar_{{ i }}",
Column,
Sequence,
Mapping,
Column, # type: ignore
Sequence, # type: ignore
Mapping, # type: ignore
UUID,
datetime,
float,
@@ -106,14 +105,14 @@ def select( # type: ignore
# Generated overloads end
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
if len(entities) == 1:
return SelectOfScalar._create(*entities, **kw) # type: ignore
return Select._create(*entities, **kw)
return Select._create(*entities, **kw) # type: ignore
# TODO: add several @overload from Python types to SQLAlchemy equivalents
def col(column_expression: Any) -> ColumnClause:
def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression

View File

@@ -1,13 +1,14 @@
import uuid
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import types
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.types import CHAR, TypeDecorator
class AutoString(types.TypeDecorator):
class AutoString(types.TypeDecorator): # type: ignore
impl = types.String
cache_ok = True
@@ -22,7 +23,7 @@ class AutoString(types.TypeDecorator):
# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
# with small modifications
class GUID(TypeDecorator):
class GUID(TypeDecorator): # type: ignore
"""Platform-independent GUID type.
Uses PostgreSQL's UUID type, otherwise uses
@@ -33,13 +34,13 @@ class GUID(TypeDecorator):
impl = CHAR
cache_ok = True
def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
return dialect.type_descriptor(UUID()) # type: ignore
else:
return dialect.type_descriptor(CHAR(32))
return dialect.type_descriptor(CHAR(32)) # type: ignore
def process_bind_param(self, value, dialect):
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
if value is None:
return value
elif dialect.name == "postgresql":
@@ -51,10 +52,10 @@ class GUID(TypeDecorator):
# hexstring
return f"{value.int:x}"
def process_result_value(self, value, dialect):
def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]:
if value is None:
return value
else:
if not isinstance(value, uuid.UUID):
value = uuid.UUID(value)
return value
return cast(uuid.UUID, value)

View File

View File

@@ -0,0 +1,44 @@
from decimal import Decimal
from unittest.mock import patch
from sqlmodel import create_engine
from ...conftest import get_testing_print_function
expected_calls = [
[
"Hero 1:",
{
"name": "Deadpond",
"age": None,
"id": 1,
"secret_name": "Dive Wilson",
"money": Decimal("1.100"),
},
],
[
"Hero 2:",
{
"name": "Rusty-Man",
"age": 48,
"id": 3,
"secret_name": "Tommy Sharp",
"money": Decimal("2.200"),
},
],
["Total money: 3.300"],
]
def test_tutorial(clear_sqlmodel):
from docs_src.advanced.decimal import tutorial001 as mod
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []
new_print = get_testing_print_function(calls)
with patch("builtins.print", new=new_print):
mod.main()
assert calls == expected_calls