Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
02697459b8 | ||
|
|
7eadc90558 | ||
|
|
95c02962ba | ||
|
|
75540f9728 | ||
|
|
580f372059 | ||
|
|
1c276ef88f | ||
|
|
14a9788eb1 | ||
|
|
dbcaa50c69 | ||
|
|
362eb81701 | ||
|
|
a36c6d5778 | ||
|
|
82935cae9f | ||
|
|
455794da2c | ||
|
|
55259b3c8b | ||
|
|
328c8c725d | ||
|
|
e30c7ef4e9 | ||
|
|
02da85c9ec | ||
|
|
878e230782 | ||
|
|
1da849ac48 | ||
|
|
af03df88ac | ||
|
|
d80a2fb7ed | ||
|
|
230911ab42 |
8
.github/workflows/build-docs.yml
vendored
8
.github/workflows/build-docs.yml
vendored
@@ -51,8 +51,16 @@ jobs:
|
||||
- name: Install Material for MkDocs Insiders
|
||||
if: github.event.pull_request.head.repo.fork == false && steps.cache.outputs.cache-hit != 'true'
|
||||
run: python -m poetry run pip install git+https://${{ secrets.ACTIONS_TOKEN }}@github.com/squidfunk/mkdocs-material-insiders.git
|
||||
- uses: actions/cache@v2
|
||||
with:
|
||||
key: mkdocs-cards-${{ github.ref }}
|
||||
path: .cache
|
||||
- name: Build Docs
|
||||
if: github.event.pull_request.head.repo.fork == true
|
||||
run: python -m poetry run mkdocs build
|
||||
- name: Build Docs with Insiders
|
||||
if: github.event.pull_request.head.repo.fork == false
|
||||
run: python -m poetry run mkdocs build --config-file mkdocs.insiders.yml
|
||||
- name: Zip docs
|
||||
run: python -m poetry run bash ./scripts/zip-docs.sh
|
||||
- uses: actions/upload-artifact@v2
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -11,3 +11,4 @@ htmlcov
|
||||
coverage.xml
|
||||
site
|
||||
*.db
|
||||
.cache
|
||||
|
||||
148
docs/advanced/decimal.md
Normal file
148
docs/advanced/decimal.md
Normal file
@@ -0,0 +1,148 @@
|
||||
# Decimal Numbers
|
||||
|
||||
In some cases you might need to be able to store decimal numbers with guarantees about the precision.
|
||||
|
||||
This is particularly important if you are storing things like **currencies**, **prices**, **accounts**, and others, as you would want to know that you wouldn't have rounding errors.
|
||||
|
||||
As an example, if you open Python and sum `1.1` + `2.2` you would expect to see `3.3`, but you will actually get `3.3000000000000003`:
|
||||
|
||||
```Python
|
||||
>>> 1.1 + 2.2
|
||||
3.3000000000000003
|
||||
```
|
||||
|
||||
This is because of the way numbers are stored in "ones and zeros" (binary). But Python has a module and some types to have strict decimal values. You can read more about it in the official <a href="https://docs.python.org/3/library/decimal.html" class="external-link" target="_blank">Python docs for Decimal</a>.
|
||||
|
||||
Because databases store data in the same ways as computers (in binary), they would have the same types of issues. And because of that, they also have a special **decimal** type.
|
||||
|
||||
In most cases this would probably not be a problem, for example measuring views in a video, or the life bar in a videogame. But as you can imagine, this is particularly important when dealing with **money** and **finances**.
|
||||
|
||||
## Decimal Types
|
||||
|
||||
Pydantic has special support for `Decimal` types using the <a href="https://pydantic-docs.helpmanual.io/usage/types/#arguments-to-condecimal" class="external-link" target="_blank">`condecimal()` special function</a>.
|
||||
|
||||
!!! tip
|
||||
Pydantic 1.9, that will be released soon, has improved support for `Decimal` types, without needing to use the `condecimal()` function.
|
||||
|
||||
But meanwhile, you can already use this feature with `condecimal()` in **SQLModel** it as it's explained here.
|
||||
|
||||
When you use `condecimal()` you can specify the number of digits and decimal places to support. They will be validated by Pydantic (for example when using FastAPI) and the same information will also be used for the database columns.
|
||||
|
||||
!!! info
|
||||
For the database, **SQLModel** will use <a href="https://docs.sqlalchemy.org/en/14/core/type_basics.html#sqlalchemy.types.DECIMAL" class="external-link" target="_blank">SQLAlchemy's `DECIMAL` type</a>.
|
||||
|
||||
## Decimals in SQLModel
|
||||
|
||||
Let's say that each hero in the database will have an amount of money. We could make that field a `Decimal` type using the `condecimal()` function:
|
||||
|
||||
```{.python .annotate hl_lines="12" }
|
||||
{!./docs_src/advanced/decimal/tutorial001.py[ln:1-12]!}
|
||||
|
||||
# More code here later 👇
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>👀 Full file preview</summary>
|
||||
|
||||
```Python
|
||||
{!./docs_src/advanced/decimal/tutorial001.py!}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
Here we are saying that `money` can have at most `5` digits with `max_digits`, **this includes the integers** (to the left of the decimal dot) **and the decimals** (to the right of the decimal dot).
|
||||
|
||||
We are also saying that the number of decimal places (to the right of the decimal dot) is `3`, so we can have **3 decimal digits** for these numbers in the `money` field. This means that we will have **2 digits for the integer part** and **3 digits for the decimal part**.
|
||||
|
||||
✅ So, for example, these are all valid numbers for the `money` field:
|
||||
|
||||
* `12.345`
|
||||
* `12.3`
|
||||
* `12`
|
||||
* `1.2`
|
||||
* `0.123`
|
||||
* `0`
|
||||
|
||||
🚫 But these are all invalid numbers for that `money` field:
|
||||
|
||||
* `1.2345`
|
||||
* This number has more than 3 decimal places.
|
||||
* `123.234`
|
||||
* This number has more than 5 digits in total (integer and decimal part).
|
||||
* `123`
|
||||
* Even though this number doesn't have any decimals, we still have 3 places saved for them, which means that we can **only use 2 places** for the **integer part**, and this number has 3 integer digits. So, the allowed number of integer digits is `max_digits` - `decimal_places` = 2.
|
||||
|
||||
!!! tip
|
||||
Make sure you adjust the number of digits and decimal places for your own needs, in your own application. 🤓
|
||||
|
||||
## Create models with Decimals
|
||||
|
||||
When creating new models you can actually pass normal (`float`) numbers, Pydantic will automatically convert them to `Decimal` types, and **SQLModel** will store them as `Decimal` types in the database (using SQLAlchemy).
|
||||
|
||||
```Python hl_lines="4-6"
|
||||
# Code above omitted 👆
|
||||
|
||||
{!./docs_src/advanced/decimal/tutorial001.py[ln:25-35]!}
|
||||
|
||||
# Code below omitted 👇
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>👀 Full file preview</summary>
|
||||
|
||||
```Python
|
||||
{!./docs_src/advanced/decimal/tutorial001.py!}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Select Decimal data
|
||||
|
||||
Then, when working with Decimal types, you can confirm that they indeed avoid those rounding errors from floats:
|
||||
|
||||
```Python hl_lines="15-16"
|
||||
# Code above omitted 👆
|
||||
|
||||
{!./docs_src/advanced/decimal/tutorial001.py[ln:38-51]!}
|
||||
|
||||
# Code below omitted 👇
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>👀 Full file preview</summary>
|
||||
|
||||
```Python
|
||||
{!./docs_src/advanced/decimal/tutorial001.py!}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Review the results
|
||||
|
||||
Now if you run this, instead of printing the unexpected number `3.3000000000000003`, it prints `3.300`:
|
||||
|
||||
<div class="termy">
|
||||
|
||||
```console
|
||||
$ python app.py
|
||||
|
||||
// Some boilerplate and previous output omitted 😉
|
||||
|
||||
// The type of money is Decimal('1.100')
|
||||
Hero 1: id=1 secret_name='Dive Wilson' age=None name='Deadpond' money=Decimal('1.100')
|
||||
|
||||
// More output omitted here 🤓
|
||||
|
||||
// The type of money is Decimal('1.100')
|
||||
Hero 2: id=3 secret_name='Tommy Sharp' age=48 name='Rusty-Man' money=Decimal('2.200')
|
||||
|
||||
// No rounding errors, just 3.3! 🎉
|
||||
Total money: 3.300
|
||||
```
|
||||
|
||||
</div>
|
||||
|
||||
!!! warning
|
||||
Although Decimal types are supported and used in the Python side, not all databases support it. In particular, SQLite doesn't support decimals, so it will convert them to the same floating `NUMERIC` type it supports.
|
||||
|
||||
But decimals are supported by most of the other SQL databases. 🎉
|
||||
@@ -1,12 +1,10 @@
|
||||
# Advanced User Guide
|
||||
|
||||
The **Advanced User Guide** will be coming soon to a <del>theater</del> **documentation** near you... 😅
|
||||
The **Advanced User Guide** is gradually growing, you can already read about some advanced topics.
|
||||
|
||||
I just have to `add` it, `commit` it, and `refresh` it. 😉
|
||||
At some point it will include:
|
||||
|
||||
It will include:
|
||||
|
||||
* How to use the `async` and `await` with the async session.
|
||||
* How to use `async` and `await` with the async session.
|
||||
* How to run migrations.
|
||||
* How to combine **SQLModel** models with SQLAlchemy.
|
||||
* ...and more.
|
||||
* ...and more. 🤓
|
||||
|
||||
@@ -3,6 +3,32 @@
|
||||
## Latest Changes
|
||||
|
||||
|
||||
## 0.0.5
|
||||
|
||||
### Features
|
||||
|
||||
* ✨ Add support for Decimal fields from Pydantic and SQLAlchemy. Original PR [#103](https://github.com/tiangolo/sqlmodel/pull/103) by [@robcxyz](https://github.com/robcxyz). New docs: [Advanced User Guide - Decimal Numbers](https://sqlmodel.tiangolo.com/advanced/decimal/).
|
||||
|
||||
### Docs
|
||||
|
||||
* ✏ Update decimal tutorial source for consistency. PR [#188](https://github.com/tiangolo/sqlmodel/pull/188) by [@tiangolo](https://github.com/tiangolo).
|
||||
|
||||
### Internal
|
||||
|
||||
* 🔧 Split MkDocs insiders build in CI to support building from PRs. PR [#186](https://github.com/tiangolo/sqlmodel/pull/186) by [@tiangolo](https://github.com/tiangolo).
|
||||
* 🎨 Format `expression.py` and expression template, currently needed by CI. PR [#187](https://github.com/tiangolo/sqlmodel/pull/187) by [@tiangolo](https://github.com/tiangolo).
|
||||
* 🐛Fix docs light/dark theme switcher. PR [#1](https://github.com/tiangolo/sqlmodel/pull/1) by [@Lehoczky](https://github.com/Lehoczky).
|
||||
* 🔧 Add MkDocs Material social cards. PR [#90](https://github.com/tiangolo/sqlmodel/pull/90) by [@tiangolo](https://github.com/tiangolo).
|
||||
* ✨ Update type annotations and upgrade mypy. PR [#173](https://github.com/tiangolo/sqlmodel/pull/173) by [@tiangolo](https://github.com/tiangolo).
|
||||
|
||||
## 0.0.4
|
||||
|
||||
* 🎨 Fix type detection of select results in PyCharm. PR [#15](https://github.com/tiangolo/sqlmodel/pull/15) by [@tiangolo](https://github.com/tiangolo).
|
||||
|
||||
## 0.0.3
|
||||
|
||||
* ⬆️ Update and relax specification range for `sqlalchemy-stubs`. PR [#4](https://github.com/tiangolo/sqlmodel/pull/4) by [@tiangolo](https://github.com/tiangolo).
|
||||
|
||||
## 0.0.2
|
||||
|
||||
* This includes several small bug fixes detected during the first CI runs.
|
||||
|
||||
0
docs_src/advanced/__init__.py
Normal file
0
docs_src/advanced/__init__.py
Normal file
0
docs_src/advanced/decimal/__init__.py
Normal file
0
docs_src/advanced/decimal/__init__.py
Normal file
61
docs_src/advanced/decimal/tutorial001.py
Normal file
61
docs_src/advanced/decimal/tutorial001.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import condecimal
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
|
||||
class Hero(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
name: str
|
||||
secret_name: str
|
||||
age: Optional[int] = None
|
||||
money: condecimal(max_digits=5, decimal_places=3) = Field(default=0)
|
||||
|
||||
|
||||
sqlite_file_name = "database.db"
|
||||
sqlite_url = f"sqlite:///{sqlite_file_name}"
|
||||
|
||||
engine = create_engine(sqlite_url, echo=True)
|
||||
|
||||
|
||||
def create_db_and_tables():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
|
||||
def create_heroes():
|
||||
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson", money=1.1)
|
||||
hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador", money=0.001)
|
||||
hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48, money=2.2)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(hero_1)
|
||||
session.add(hero_2)
|
||||
session.add(hero_3)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def select_heroes():
|
||||
with Session(engine) as session:
|
||||
statement = select(Hero).where(Hero.name == "Deadpond")
|
||||
results = session.exec(statement)
|
||||
hero_1 = results.one()
|
||||
print("Hero 1:", hero_1)
|
||||
|
||||
statement = select(Hero).where(Hero.name == "Rusty-Man")
|
||||
results = session.exec(statement)
|
||||
hero_2 = results.one()
|
||||
print("Hero 2:", hero_2)
|
||||
|
||||
total_money = hero_1.money + hero_2.money
|
||||
print(f"Total money: {total_money}")
|
||||
|
||||
|
||||
def main():
|
||||
create_db_and_tables()
|
||||
create_heroes()
|
||||
select_heroes()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
4
mkdocs.insiders.yml
Normal file
4
mkdocs.insiders.yml
Normal file
@@ -0,0 +1,4 @@
|
||||
INHERIT: mkdocs.yml
|
||||
plugins:
|
||||
- search
|
||||
- social
|
||||
11
mkdocs.yml
11
mkdocs.yml
@@ -8,14 +8,14 @@ theme:
|
||||
primary: deep purple
|
||||
accent: amber
|
||||
toggle:
|
||||
icon: material/lightbulb-outline
|
||||
name: Switch to light mode
|
||||
icon: material/lightbulb
|
||||
name: Switch to dark mode
|
||||
- scheme: slate
|
||||
primary: deep purple
|
||||
accent: amber
|
||||
toggle:
|
||||
icon: material/lightbulb
|
||||
name: Switch to dark mode
|
||||
icon: material/lightbulb-outline
|
||||
name: Switch to light mode
|
||||
features:
|
||||
- search.suggest
|
||||
- search.highlight
|
||||
@@ -30,8 +30,6 @@ edit_uri: ''
|
||||
google_analytics:
|
||||
- UA-205713594-2
|
||||
- auto
|
||||
plugins:
|
||||
- search
|
||||
nav:
|
||||
- SQLModel: index.md
|
||||
- features.md
|
||||
@@ -86,6 +84,7 @@ nav:
|
||||
- tutorial/fastapi/tests.md
|
||||
- Advanced User Guide:
|
||||
- advanced/index.md
|
||||
- advanced/decimal.md
|
||||
- alternatives.md
|
||||
- help.md
|
||||
- contributing.md
|
||||
|
||||
@@ -33,11 +33,11 @@ classifiers = [
|
||||
python = "^3.6.1"
|
||||
SQLAlchemy = ">=1.4.17,<1.5.0"
|
||||
pydantic = "^1.8.2"
|
||||
sqlalchemy2-stubs = "^0.0.2-alpha.5"
|
||||
sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^6.2.4"
|
||||
mypy = "^0.812"
|
||||
mypy = "^0.910"
|
||||
flake8 = "^3.9.2"
|
||||
black = {version = "^21.5-beta.1", python = "^3.7"}
|
||||
mkdocs = "^1.2.1"
|
||||
@@ -98,3 +98,7 @@ warn_return_any = true
|
||||
implicit_reexport = false
|
||||
strict_equality = true
|
||||
# --strict end
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "sqlmodel.sql.expression"
|
||||
warn_unused_ignores = false
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.0.2"
|
||||
__version__ = "0.0.5"
|
||||
|
||||
# Re-export from SQLAlchemy
|
||||
from sqlalchemy.engine import create_mock_engine as create_mock_engine
|
||||
|
||||
@@ -136,4 +136,4 @@ def create_engine(
|
||||
if not isinstance(query_cache_size, _DefaultPlaceholder):
|
||||
current_kwargs["query_cache_size"] = query_cache_size
|
||||
current_kwargs.update(kwargs)
|
||||
return _create_engine(url, **current_kwargs)
|
||||
return _create_engine(url, **current_kwargs) # type: ignore
|
||||
|
||||
@@ -23,7 +23,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
|
||||
return super().__iter__()
|
||||
|
||||
def __next__(self) -> _T:
|
||||
return super().__next__()
|
||||
return super().__next__() # type: ignore
|
||||
|
||||
def first(self) -> Optional[_T]:
|
||||
return super().first()
|
||||
@@ -32,7 +32,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
|
||||
return super().one_or_none()
|
||||
|
||||
def one(self) -> _T:
|
||||
return super().one()
|
||||
return super().one() # type: ignore
|
||||
|
||||
|
||||
class Result(_Result, Generic[_T]):
|
||||
@@ -70,10 +70,10 @@ class Result(_Result, Generic[_T]):
|
||||
return super().scalar_one() # type: ignore
|
||||
|
||||
def scalar_one_or_none(self) -> Optional[_T]:
|
||||
return super().scalar_one_or_none() # type: ignore
|
||||
return super().scalar_one_or_none()
|
||||
|
||||
def one(self) -> _T: # type: ignore
|
||||
return super().one() # type: ignore
|
||||
|
||||
def scalar(self) -> Optional[_T]:
|
||||
return super().scalar() # type: ignore
|
||||
return super().scalar()
|
||||
|
||||
@@ -21,7 +21,7 @@ class AsyncSession(_AsyncSession):
|
||||
self,
|
||||
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
|
||||
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
|
||||
**kw,
|
||||
**kw: Any,
|
||||
):
|
||||
# All the same code of the original AsyncSession
|
||||
kw["future"] = True
|
||||
@@ -52,7 +52,7 @@ class AsyncSession(_AsyncSession):
|
||||
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
|
||||
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
|
||||
|
||||
return await greenlet_spawn( # type: ignore
|
||||
return await greenlet_spawn(
|
||||
self.sync_session.exec,
|
||||
statement,
|
||||
params=params,
|
||||
|
||||
@@ -101,7 +101,7 @@ class RelationshipInfo(Representation):
|
||||
*,
|
||||
back_populates: Optional[str] = None,
|
||||
link_model: Optional[Any] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
|
||||
sa_relationship_args: Optional[Sequence[Any]] = None,
|
||||
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> None:
|
||||
@@ -127,32 +127,32 @@ def Field(
|
||||
default: Any = Undefined,
|
||||
*,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: str = None,
|
||||
title: str = None,
|
||||
description: str = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
include: Union[
|
||||
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||
] = None,
|
||||
const: bool = None,
|
||||
gt: float = None,
|
||||
ge: float = None,
|
||||
lt: float = None,
|
||||
le: float = None,
|
||||
multiple_of: float = None,
|
||||
min_items: int = None,
|
||||
max_items: int = None,
|
||||
min_length: int = None,
|
||||
max_length: int = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: str = None,
|
||||
regex: Optional[str] = None,
|
||||
primary_key: bool = False,
|
||||
foreign_key: Optional[Any] = None,
|
||||
nullable: Union[bool, UndefinedType] = Undefined,
|
||||
index: Union[bool, UndefinedType] = Undefined,
|
||||
sa_column: Union[Column, UndefinedType] = Undefined,
|
||||
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
|
||||
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
|
||||
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
|
||||
schema_extra: Optional[Dict[str, Any]] = None,
|
||||
@@ -195,7 +195,7 @@ def Relationship(
|
||||
*,
|
||||
back_populates: Optional[str] = None,
|
||||
link_model: Optional[Any] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None,
|
||||
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
|
||||
sa_relationship_args: Optional[Sequence[Any]] = None,
|
||||
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> Any:
|
||||
@@ -217,19 +217,25 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
|
||||
# Replicate SQLAlchemy
|
||||
def __setattr__(cls, name: str, value: Any) -> None:
|
||||
if getattr(cls.__config__, "table", False): # type: ignore
|
||||
if getattr(cls.__config__, "table", False):
|
||||
DeclarativeMeta.__setattr__(cls, name, value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __delattr__(cls, name: str) -> None:
|
||||
if getattr(cls.__config__, "table", False): # type: ignore
|
||||
if getattr(cls.__config__, "table", False):
|
||||
DeclarativeMeta.__delattr__(cls, name)
|
||||
else:
|
||||
super().__delattr__(name)
|
||||
|
||||
# From Pydantic
|
||||
def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: Tuple[Type[Any], ...],
|
||||
class_dict: Dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
relationships: Dict[str, RelationshipInfo] = {}
|
||||
dict_for_pydantic = {}
|
||||
original_annotations = resolve_annotations(
|
||||
@@ -342,7 +348,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
)
|
||||
relationship_to = temp_field.type_
|
||||
if isinstance(temp_field.type_, ForwardRef):
|
||||
relationship_to = temp_field.type_.__forward_arg__ # type: ignore
|
||||
relationship_to = temp_field.type_.__forward_arg__
|
||||
rel_kwargs: Dict[str, Any] = {}
|
||||
if rel_info.back_populates:
|
||||
rel_kwargs["back_populates"] = rel_info.back_populates
|
||||
@@ -360,7 +366,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
rel_args.extend(rel_info.sa_relationship_args)
|
||||
if rel_info.sa_relationship_kwargs:
|
||||
rel_kwargs.update(rel_info.sa_relationship_kwargs)
|
||||
rel_value: RelationshipProperty = relationship(
|
||||
rel_value: RelationshipProperty = relationship( # type: ignore
|
||||
relationship_to, *rel_args, **rel_kwargs
|
||||
)
|
||||
dict_used[rel_name] = rel_value
|
||||
@@ -393,7 +399,10 @@ def get_sqlachemy_type(field: ModelField) -> Any:
|
||||
if issubclass(field.type_, bytes):
|
||||
return LargeBinary
|
||||
if issubclass(field.type_, Decimal):
|
||||
return Numeric
|
||||
return Numeric(
|
||||
precision=getattr(field.type_, "max_digits", None),
|
||||
scale=getattr(field.type_, "decimal_places", None),
|
||||
)
|
||||
if issubclass(field.type_, ipaddress.IPv4Address):
|
||||
return AutoString
|
||||
if issubclass(field.type_, ipaddress.IPv4Network):
|
||||
@@ -408,7 +417,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
|
||||
return GUID
|
||||
|
||||
|
||||
def get_column_from_field(field: ModelField) -> Column:
|
||||
def get_column_from_field(field: ModelField) -> Column: # type: ignore
|
||||
sa_column = getattr(field.field_info, "sa_column", Undefined)
|
||||
if isinstance(sa_column, Column):
|
||||
return sa_column
|
||||
@@ -440,10 +449,10 @@ def get_column_from_field(field: ModelField) -> Column:
|
||||
kwargs["default"] = sa_default
|
||||
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
|
||||
if sa_column_args is not Undefined:
|
||||
args.extend(list(cast(Sequence, sa_column_args)))
|
||||
args.extend(list(cast(Sequence[Any], sa_column_args)))
|
||||
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
|
||||
if sa_column_kwargs is not Undefined:
|
||||
kwargs.update(cast(dict, sa_column_kwargs))
|
||||
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
|
||||
return Column(sa_type, *args, **kwargs)
|
||||
|
||||
|
||||
@@ -452,24 +461,27 @@ class_registry = weakref.WeakValueDictionary() # type: ignore
|
||||
default_registry = registry()
|
||||
|
||||
|
||||
def _value_items_is_true(v) -> bool:
|
||||
def _value_items_is_true(v: Any) -> bool:
|
||||
# Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
|
||||
# the current latest, Pydantic 1.8.2
|
||||
return v is True or v is ...
|
||||
|
||||
|
||||
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
|
||||
|
||||
|
||||
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
|
||||
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
|
||||
__slots__ = ("__weakref__",)
|
||||
__tablename__: ClassVar[Union[str, Callable[..., str]]]
|
||||
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
|
||||
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore
|
||||
__name__: ClassVar[str]
|
||||
metadata: ClassVar[MetaData]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
def __new__(cls, *args, **kwargs) -> Any:
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
new_object = super().__new__(cls)
|
||||
# SQLAlchemy doesn't call __init__ on the base class
|
||||
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
|
||||
@@ -520,7 +532,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
super().__setattr__(name, value)
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
|
||||
def from_orm(
|
||||
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
|
||||
) -> _TSQLModel:
|
||||
# Duplicated from Pydantic
|
||||
if not cls.__config__.orm_mode:
|
||||
raise ConfigError(
|
||||
@@ -533,7 +547,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
# End SQLModel support dict
|
||||
if not getattr(cls.__config__, "table", False):
|
||||
# If not table, normal Pydantic code
|
||||
m = cls.__new__(cls)
|
||||
m: _TSQLModel = cls.__new__(cls)
|
||||
else:
|
||||
# If table, create the new instance normally to make SQLAlchemy create
|
||||
# the _sa_instance_state attribute
|
||||
@@ -554,7 +568,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
|
||||
|
||||
@classmethod
|
||||
def parse_obj(
|
||||
cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
|
||||
cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
|
||||
) -> "SQLModel":
|
||||
obj = cls._enforce_dict_if_root(obj)
|
||||
# SQLModel, support update dict
|
||||
|
||||
@@ -10,14 +10,14 @@ from typing_extensions import Literal
|
||||
from ..engine.result import Result, ScalarResult
|
||||
from ..sql.base import Executable
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_TSelectParam = TypeVar("_TSelectParam")
|
||||
|
||||
|
||||
class Session(_Session):
|
||||
@overload
|
||||
def exec(
|
||||
self,
|
||||
statement: Select[_T],
|
||||
statement: Select[_TSelectParam],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
@@ -25,13 +25,13 @@ class Session(_Session):
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
**kw: Any,
|
||||
) -> Union[Result[_T]]:
|
||||
) -> Result[_TSelectParam]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def exec(
|
||||
self,
|
||||
statement: SelectOfScalar[_T],
|
||||
statement: SelectOfScalar[_TSelectParam],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
@@ -39,12 +39,16 @@ class Session(_Session):
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
**kw: Any,
|
||||
) -> Union[ScalarResult[_T]]:
|
||||
) -> ScalarResult[_TSelectParam]:
|
||||
...
|
||||
|
||||
def exec(
|
||||
self,
|
||||
statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
|
||||
statement: Union[
|
||||
Select[_TSelectParam],
|
||||
SelectOfScalar[_TSelectParam],
|
||||
Executable[_TSelectParam],
|
||||
],
|
||||
*,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
@@ -52,11 +56,11 @@ class Session(_Session):
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
**kw: Any,
|
||||
) -> Union[Result[_T], ScalarResult[_T]]:
|
||||
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
|
||||
results = super().execute(
|
||||
statement,
|
||||
params=params,
|
||||
execution_options=execution_options, # type: ignore
|
||||
execution_options=execution_options,
|
||||
bind_arguments=bind_arguments,
|
||||
_parent_execute_state=_parent_execute_state,
|
||||
_add_event=_add_event,
|
||||
@@ -70,7 +74,7 @@ class Session(_Session):
|
||||
self,
|
||||
statement: _Executable,
|
||||
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
|
||||
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
|
||||
execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[Mapping[str, Any]] = None,
|
||||
_parent_execute_state: Optional[Any] = None,
|
||||
_add_event: Optional[Any] = None,
|
||||
@@ -97,7 +101,7 @@ class Session(_Session):
|
||||
return super().execute( # type: ignore
|
||||
statement,
|
||||
params=params,
|
||||
execution_options=execution_options, # type: ignore
|
||||
execution_options=execution_options,
|
||||
bind_arguments=bind_arguments,
|
||||
_parent_execute_state=_parent_execute_state,
|
||||
_add_event=_add_event,
|
||||
@@ -118,13 +122,13 @@ class Session(_Session):
|
||||
|
||||
def get(
|
||||
self,
|
||||
entity: Type[_T],
|
||||
entity: Type[_TSelectParam],
|
||||
ident: Any,
|
||||
options: Optional[Sequence[Any]] = None,
|
||||
populate_existing: bool = False,
|
||||
with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
|
||||
identity_token: Optional[Any] = None,
|
||||
) -> Optional[_T]:
|
||||
) -> Optional[_TSelectParam]:
|
||||
return super().get(
|
||||
entity,
|
||||
ident,
|
||||
|
||||
@@ -6,6 +6,4 @@ _T = TypeVar("_T")
|
||||
|
||||
|
||||
class Executable(_Executable, Generic[_T]):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
|
||||
super(_Executable, self).__init__(*args, **kwargs)
|
||||
pass
|
||||
|
||||
@@ -38,17 +38,16 @@ if sys.version_info.minor >= 7:
|
||||
class SelectOfScalar(_Select, Generic[_TSelect]):
|
||||
pass
|
||||
|
||||
|
||||
else:
|
||||
from typing import GenericMeta # type: ignore
|
||||
|
||||
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
|
||||
pass
|
||||
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
# Cast them for editors to work correctly, from several tricks tried, this works
|
||||
@@ -65,9 +64,9 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
_TScalar_0 = TypeVar(
|
||||
"_TScalar_0",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -83,9 +82,9 @@ _TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
|
||||
|
||||
_TScalar_1 = TypeVar(
|
||||
"_TScalar_1",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -101,9 +100,9 @@ _TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
|
||||
|
||||
_TScalar_2 = TypeVar(
|
||||
"_TScalar_2",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -119,9 +118,9 @@ _TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
|
||||
|
||||
_TScalar_3 = TypeVar(
|
||||
"_TScalar_3",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -446,14 +445,14 @@ def select( # type: ignore
|
||||
# Generated overloads end
|
||||
|
||||
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
|
||||
if len(entities) == 1:
|
||||
return SelectOfScalar._create(*entities, **kw) # type: ignore
|
||||
return Select._create(*entities, **kw) # type: ignore
|
||||
|
||||
|
||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
|
||||
def col(column_expression: Any) -> ColumnClause:
|
||||
def col(column_expression: Any) -> ColumnClause: # type: ignore
|
||||
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
|
||||
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
|
||||
return column_expression
|
||||
|
||||
@@ -36,7 +36,6 @@ if sys.version_info.minor >= 7:
|
||||
class SelectOfScalar(_Select, Generic[_TSelect]):
|
||||
pass
|
||||
|
||||
|
||||
else:
|
||||
from typing import GenericMeta # type: ignore
|
||||
|
||||
@@ -63,9 +62,9 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
{% for i in range(number_of_types) %}
|
||||
_TScalar_{{ i }} = TypeVar(
|
||||
"_TScalar_{{ i }}",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -106,14 +105,14 @@ def select( # type: ignore
|
||||
|
||||
# Generated overloads end
|
||||
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
|
||||
if len(entities) == 1:
|
||||
return SelectOfScalar._create(*entities, **kw) # type: ignore
|
||||
return Select._create(*entities, **kw)
|
||||
return Select._create(*entities, **kw) # type: ignore
|
||||
|
||||
|
||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
|
||||
def col(column_expression: Any) -> ColumnClause:
|
||||
def col(column_expression: Any) -> ColumnClause: # type: ignore
|
||||
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
|
||||
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
|
||||
return column_expression
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
from sqlalchemy.types import CHAR, TypeDecorator
|
||||
|
||||
|
||||
class AutoString(types.TypeDecorator):
|
||||
class AutoString(types.TypeDecorator): # type: ignore
|
||||
|
||||
impl = types.String
|
||||
cache_ok = True
|
||||
@@ -22,7 +23,7 @@ class AutoString(types.TypeDecorator):
|
||||
|
||||
# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
|
||||
# with small modifications
|
||||
class GUID(TypeDecorator):
|
||||
class GUID(TypeDecorator): # type: ignore
|
||||
"""Platform-independent GUID type.
|
||||
|
||||
Uses PostgreSQL's UUID type, otherwise uses
|
||||
@@ -33,13 +34,13 @@ class GUID(TypeDecorator):
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(UUID())
|
||||
return dialect.type_descriptor(UUID()) # type: ignore
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(32))
|
||||
return dialect.type_descriptor(CHAR(32)) # type: ignore
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == "postgresql":
|
||||
@@ -51,10 +52,10 @@ class GUID(TypeDecorator):
|
||||
# hexstring
|
||||
return f"{value.int:x}"
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]:
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
if not isinstance(value, uuid.UUID):
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
return cast(uuid.UUID, value)
|
||||
|
||||
0
tests/test_advanced/__init__.py
Normal file
0
tests/test_advanced/__init__.py
Normal file
0
tests/test_advanced/test_decimal/__init__.py
Normal file
0
tests/test_advanced/test_decimal/__init__.py
Normal file
44
tests/test_advanced/test_decimal/test_tutorial001.py
Normal file
44
tests/test_advanced/test_decimal/test_tutorial001.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlmodel import create_engine
|
||||
|
||||
from ...conftest import get_testing_print_function
|
||||
|
||||
expected_calls = [
|
||||
[
|
||||
"Hero 1:",
|
||||
{
|
||||
"name": "Deadpond",
|
||||
"age": None,
|
||||
"id": 1,
|
||||
"secret_name": "Dive Wilson",
|
||||
"money": Decimal("1.100"),
|
||||
},
|
||||
],
|
||||
[
|
||||
"Hero 2:",
|
||||
{
|
||||
"name": "Rusty-Man",
|
||||
"age": 48,
|
||||
"id": 3,
|
||||
"secret_name": "Tommy Sharp",
|
||||
"money": Decimal("2.200"),
|
||||
},
|
||||
],
|
||||
["Total money: 3.300"],
|
||||
]
|
||||
|
||||
|
||||
def test_tutorial(clear_sqlmodel):
|
||||
from docs_src.advanced.decimal import tutorial001 as mod
|
||||
|
||||
mod.sqlite_url = "sqlite://"
|
||||
mod.engine = create_engine(mod.sqlite_url)
|
||||
calls = []
|
||||
|
||||
new_print = get_testing_print_function(calls)
|
||||
|
||||
with patch("builtins.print", new=new_print):
|
||||
mod.main()
|
||||
assert calls == expected_calls
|
||||
Reference in New Issue
Block a user