✨ Do not allow invalid combinations of field parameters for columns and relationships, sa_column
excludes sa_column_args
, primary_key
, nullable
, etc. (#681)
* ♻️ Make sa_column exclusive, do not allow incompatible arguments, sa_column_args, primary_key, etc * ✅ Add tests for new errors when incorrectly using sa_column * ✅ Add tests for sa_column_args and sa_column_kwargs * ♻️ Do not allow sa_relationship with sa_relationship_args or sa_relationship_kwargs * ✅ Add tests for relationship errors * ✅ Fix test for sa_column_args
This commit is contained in:
parent
e4e1385eed
commit
717594ef13
151
sqlmodel/main.py
151
sqlmodel/main.py
@ -22,6 +22,7 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseConfig, BaseModel
|
from pydantic import BaseConfig, BaseModel
|
||||||
@ -87,6 +88,28 @@ class FieldInfo(PydanticFieldInfo):
|
|||||||
"Passing sa_column_kwargs is not supported when "
|
"Passing sa_column_kwargs is not supported when "
|
||||||
"also passing a sa_column"
|
"also passing a sa_column"
|
||||||
)
|
)
|
||||||
|
if primary_key is not Undefined:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Passing primary_key is not supported when "
|
||||||
|
"also passing a sa_column"
|
||||||
|
)
|
||||||
|
if nullable is not Undefined:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Passing nullable is not supported when " "also passing a sa_column"
|
||||||
|
)
|
||||||
|
if foreign_key is not Undefined:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Passing foreign_key is not supported when "
|
||||||
|
"also passing a sa_column"
|
||||||
|
)
|
||||||
|
if unique is not Undefined:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Passing unique is not supported when " "also passing a sa_column"
|
||||||
|
)
|
||||||
|
if index is not Undefined:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Passing index is not supported when " "also passing a sa_column"
|
||||||
|
)
|
||||||
super().__init__(default=default, **kwargs)
|
super().__init__(default=default, **kwargs)
|
||||||
self.primary_key = primary_key
|
self.primary_key = primary_key
|
||||||
self.nullable = nullable
|
self.nullable = nullable
|
||||||
@ -126,6 +149,7 @@ class RelationshipInfo(Representation):
|
|||||||
self.sa_relationship_kwargs = sa_relationship_kwargs
|
self.sa_relationship_kwargs = sa_relationship_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
def Field(
|
def Field(
|
||||||
default: Any = Undefined,
|
default: Any = Undefined,
|
||||||
*,
|
*,
|
||||||
@ -156,9 +180,88 @@ def Field(
|
|||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
discriminator: Optional[str] = None,
|
discriminator: Optional[str] = None,
|
||||||
repr: bool = True,
|
repr: bool = True,
|
||||||
primary_key: bool = False,
|
primary_key: Union[bool, UndefinedType] = Undefined,
|
||||||
foreign_key: Optional[Any] = None,
|
foreign_key: Any = Undefined,
|
||||||
unique: bool = False,
|
unique: Union[bool, UndefinedType] = Undefined,
|
||||||
|
nullable: Union[bool, UndefinedType] = Undefined,
|
||||||
|
index: Union[bool, UndefinedType] = Undefined,
|
||||||
|
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
|
||||||
|
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
|
||||||
|
schema_extra: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def Field(
|
||||||
|
default: Any = Undefined,
|
||||||
|
*,
|
||||||
|
default_factory: Optional[NoArgAnyCallable] = None,
|
||||||
|
alias: Optional[str] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
exclude: Union[
|
||||||
|
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||||
|
] = None,
|
||||||
|
include: Union[
|
||||||
|
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||||
|
] = None,
|
||||||
|
const: Optional[bool] = None,
|
||||||
|
gt: Optional[float] = None,
|
||||||
|
ge: Optional[float] = None,
|
||||||
|
lt: Optional[float] = None,
|
||||||
|
le: Optional[float] = None,
|
||||||
|
multiple_of: Optional[float] = None,
|
||||||
|
max_digits: Optional[int] = None,
|
||||||
|
decimal_places: Optional[int] = None,
|
||||||
|
min_items: Optional[int] = None,
|
||||||
|
max_items: Optional[int] = None,
|
||||||
|
unique_items: Optional[bool] = None,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
allow_mutation: bool = True,
|
||||||
|
regex: Optional[str] = None,
|
||||||
|
discriminator: Optional[str] = None,
|
||||||
|
repr: bool = True,
|
||||||
|
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
|
||||||
|
schema_extra: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def Field(
|
||||||
|
default: Any = Undefined,
|
||||||
|
*,
|
||||||
|
default_factory: Optional[NoArgAnyCallable] = None,
|
||||||
|
alias: Optional[str] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
exclude: Union[
|
||||||
|
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||||
|
] = None,
|
||||||
|
include: Union[
|
||||||
|
AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
|
||||||
|
] = None,
|
||||||
|
const: Optional[bool] = None,
|
||||||
|
gt: Optional[float] = None,
|
||||||
|
ge: Optional[float] = None,
|
||||||
|
lt: Optional[float] = None,
|
||||||
|
le: Optional[float] = None,
|
||||||
|
multiple_of: Optional[float] = None,
|
||||||
|
max_digits: Optional[int] = None,
|
||||||
|
decimal_places: Optional[int] = None,
|
||||||
|
min_items: Optional[int] = None,
|
||||||
|
max_items: Optional[int] = None,
|
||||||
|
unique_items: Optional[bool] = None,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
allow_mutation: bool = True,
|
||||||
|
regex: Optional[str] = None,
|
||||||
|
discriminator: Optional[str] = None,
|
||||||
|
repr: bool = True,
|
||||||
|
primary_key: Union[bool, UndefinedType] = Undefined,
|
||||||
|
foreign_key: Any = Undefined,
|
||||||
|
unique: Union[bool, UndefinedType] = Undefined,
|
||||||
nullable: Union[bool, UndefinedType] = Undefined,
|
nullable: Union[bool, UndefinedType] = Undefined,
|
||||||
index: Union[bool, UndefinedType] = Undefined,
|
index: Union[bool, UndefinedType] = Undefined,
|
||||||
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
|
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
|
||||||
@ -206,6 +309,27 @@ def Field(
|
|||||||
return field_info
|
return field_info
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def Relationship(
|
||||||
|
*,
|
||||||
|
back_populates: Optional[str] = None,
|
||||||
|
link_model: Optional[Any] = None,
|
||||||
|
sa_relationship_args: Optional[Sequence[Any]] = None,
|
||||||
|
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
|
||||||
|
) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def Relationship(
|
||||||
|
*,
|
||||||
|
back_populates: Optional[str] = None,
|
||||||
|
link_model: Optional[Any] = None,
|
||||||
|
sa_relationship: Optional[RelationshipProperty] = None, # type: ignore
|
||||||
|
) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
def Relationship(
|
def Relationship(
|
||||||
*,
|
*,
|
||||||
back_populates: Optional[str] = None,
|
back_populates: Optional[str] = None,
|
||||||
@ -440,21 +564,28 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
|
|||||||
if isinstance(sa_column, Column):
|
if isinstance(sa_column, Column):
|
||||||
return sa_column
|
return sa_column
|
||||||
sa_type = get_sqlalchemy_type(field)
|
sa_type = get_sqlalchemy_type(field)
|
||||||
primary_key = getattr(field.field_info, "primary_key", False)
|
primary_key = getattr(field.field_info, "primary_key", Undefined)
|
||||||
|
if primary_key is Undefined:
|
||||||
|
primary_key = False
|
||||||
index = getattr(field.field_info, "index", Undefined)
|
index = getattr(field.field_info, "index", Undefined)
|
||||||
if index is Undefined:
|
if index is Undefined:
|
||||||
index = False
|
index = False
|
||||||
nullable = not primary_key and _is_field_noneable(field)
|
nullable = not primary_key and _is_field_noneable(field)
|
||||||
# Override derived nullability if the nullable property is set explicitly
|
# Override derived nullability if the nullable property is set explicitly
|
||||||
# on the field
|
# on the field
|
||||||
if hasattr(field.field_info, "nullable"):
|
field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009
|
||||||
field_nullable = getattr(field.field_info, "nullable") # noqa: B009
|
if field_nullable != Undefined:
|
||||||
if field_nullable != Undefined:
|
assert not isinstance(field_nullable, UndefinedType)
|
||||||
nullable = field_nullable
|
nullable = field_nullable
|
||||||
args = []
|
args = []
|
||||||
foreign_key = getattr(field.field_info, "foreign_key", None)
|
foreign_key = getattr(field.field_info, "foreign_key", Undefined)
|
||||||
unique = getattr(field.field_info, "unique", False)
|
if foreign_key is Undefined:
|
||||||
|
foreign_key = None
|
||||||
|
unique = getattr(field.field_info, "unique", Undefined)
|
||||||
|
if unique is Undefined:
|
||||||
|
unique = False
|
||||||
if foreign_key:
|
if foreign_key:
|
||||||
|
assert isinstance(foreign_key, str)
|
||||||
args.append(ForeignKey(foreign_key))
|
args.append(ForeignKey(foreign_key))
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"primary_key": primary_key,
|
"primary_key": primary_key,
|
||||||
|
39
tests/test_field_sa_args_kwargs.py
Normal file
39
tests/test_field_sa_args_kwargs.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey
|
||||||
|
from sqlmodel import Field, SQLModel, create_engine
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_args(clear_sqlmodel, caplog) -> None:
|
||||||
|
class Team(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
team_id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column_args=[ForeignKey("team.id")],
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = create_engine("sqlite://", echo=True)
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
create_table_log = [
|
||||||
|
message for message in caplog.messages if "CREATE TABLE hero" in message
|
||||||
|
][0]
|
||||||
|
assert "FOREIGN KEY(team_id) REFERENCES team (id)" in create_table_log
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_kargs(clear_sqlmodel, caplog) -> None:
|
||||||
|
class Item(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column_kwargs={"primary_key": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = create_engine("sqlite://", echo=True)
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
create_table_log = [
|
||||||
|
message for message in caplog.messages if "CREATE TABLE item" in message
|
||||||
|
][0]
|
||||||
|
assert "PRIMARY KEY (id)" in create_table_log
|
99
tests/test_field_sa_column.py
Normal file
99
tests/test_field_sa_column.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import Column, Integer, String
|
||||||
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_takes_precedence() -> None:
|
||||||
|
class Item(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column(String, primary_key=True, nullable=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# It would have been nullable with no sa_column
|
||||||
|
assert Item.id.nullable is False # type: ignore
|
||||||
|
assert isinstance(Item.id.type, String) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_no_sa_args() -> None:
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
|
||||||
|
class Item(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column_args=[Integer],
|
||||||
|
sa_column=Column(Integer, primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_no_sa_kargs() -> None:
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
|
||||||
|
class Item(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column_kwargs={"primary_key": True},
|
||||||
|
sa_column=Column(Integer, primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_no_primary_key() -> None:
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
|
||||||
|
class Item(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
primary_key=True,
|
||||||
|
sa_column=Column(Integer, primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_no_nullable() -> None:
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
|
||||||
|
class Item(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
nullable=True,
|
||||||
|
sa_column=Column(Integer, primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_no_foreign_key() -> None:
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
|
||||||
|
class Team(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
team_id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
foreign_key="team.id",
|
||||||
|
sa_column=Column(Integer, primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_no_unique() -> None:
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
|
||||||
|
class Item(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
unique=True,
|
||||||
|
sa_column=Column(Integer, primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_column_no_index() -> None:
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
|
||||||
|
class Item(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
index=True,
|
||||||
|
sa_column=Column(Integer, primary_key=True),
|
||||||
|
)
|
53
tests/test_field_sa_relationship.py
Normal file
53
tests/test_field_sa_relationship.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from sqlmodel import Field, Relationship, SQLModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_relationship_no_args() -> None:
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
|
||||||
|
class Team(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
name: str = Field(index=True)
|
||||||
|
headquarters: str
|
||||||
|
|
||||||
|
heroes: List["Hero"] = Relationship(
|
||||||
|
back_populates="team",
|
||||||
|
sa_relationship_args=["Hero"],
|
||||||
|
sa_relationship=relationship("Hero", back_populates="team"),
|
||||||
|
)
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
name: str = Field(index=True)
|
||||||
|
secret_name: str
|
||||||
|
age: Optional[int] = Field(default=None, index=True)
|
||||||
|
|
||||||
|
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
|
||||||
|
team: Optional[Team] = Relationship(back_populates="heroes")
|
||||||
|
|
||||||
|
|
||||||
|
def test_sa_relationship_no_kwargs() -> None:
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
|
||||||
|
class Team(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
name: str = Field(index=True)
|
||||||
|
headquarters: str
|
||||||
|
|
||||||
|
heroes: List["Hero"] = Relationship(
|
||||||
|
back_populates="team",
|
||||||
|
sa_relationship_kwargs={"lazy": "selectin"},
|
||||||
|
sa_relationship=relationship("Hero", back_populates="team"),
|
||||||
|
)
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
name: str = Field(index=True)
|
||||||
|
secret_name: str
|
||||||
|
age: Optional[int] = Field(default=None, index=True)
|
||||||
|
|
||||||
|
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
|
||||||
|
team: Optional[Team] = Relationship(back_populates="heroes")
|
Loading…
x
Reference in New Issue
Block a user