🐛 Fix support for types with Optional[Annoated[x, f()]]
, e.g. id: Optional[pydantic.UUID4]
(#1093)
This commit is contained in:
parent
4eaf8b9efb
commit
a14ab0bd3c
@ -21,7 +21,7 @@ from typing import (
|
|||||||
from pydantic import VERSION as P_VERSION
|
from pydantic import VERSION as P_VERSION
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
from typing_extensions import get_args, get_origin
|
from typing_extensions import Annotated, get_args, get_origin
|
||||||
|
|
||||||
# Reassign variable to make it reexported for mypy
|
# Reassign variable to make it reexported for mypy
|
||||||
PYDANTIC_VERSION = P_VERSION
|
PYDANTIC_VERSION = P_VERSION
|
||||||
@ -177,16 +177,17 @@ if IS_PYDANTIC_V2:
|
|||||||
return False
|
return False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_type_from_field(field: Any) -> Any:
|
def get_sa_type_from_type_annotation(annotation: Any) -> Any:
|
||||||
type_: Any = field.annotation
|
|
||||||
# Resolve Optional fields
|
# Resolve Optional fields
|
||||||
if type_ is None:
|
if annotation is None:
|
||||||
raise ValueError("Missing field type")
|
raise ValueError("Missing field type")
|
||||||
origin = get_origin(type_)
|
origin = get_origin(annotation)
|
||||||
if origin is None:
|
if origin is None:
|
||||||
return type_
|
return annotation
|
||||||
|
elif origin is Annotated:
|
||||||
|
return get_sa_type_from_type_annotation(get_args(annotation)[0])
|
||||||
if _is_union_type(origin):
|
if _is_union_type(origin):
|
||||||
bases = get_args(type_)
|
bases = get_args(annotation)
|
||||||
if len(bases) > 2:
|
if len(bases) > 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot have a (non-optional) union as a SQLAlchemy field"
|
"Cannot have a (non-optional) union as a SQLAlchemy field"
|
||||||
@ -197,9 +198,14 @@ if IS_PYDANTIC_V2:
|
|||||||
"Cannot have a (non-optional) union as a SQLAlchemy field"
|
"Cannot have a (non-optional) union as a SQLAlchemy field"
|
||||||
)
|
)
|
||||||
# Optional unions are allowed
|
# Optional unions are allowed
|
||||||
return bases[0] if bases[0] is not NoneType else bases[1]
|
use_type = bases[0] if bases[0] is not NoneType else bases[1]
|
||||||
|
return get_sa_type_from_type_annotation(use_type)
|
||||||
return origin
|
return origin
|
||||||
|
|
||||||
|
def get_sa_type_from_field(field: Any) -> Any:
|
||||||
|
type_: Any = field.annotation
|
||||||
|
return get_sa_type_from_type_annotation(type_)
|
||||||
|
|
||||||
def get_field_metadata(field: Any) -> Any:
|
def get_field_metadata(field: Any) -> Any:
|
||||||
for meta in field.metadata:
|
for meta in field.metadata:
|
||||||
if isinstance(meta, (PydanticMetadata, MaxLen)):
|
if isinstance(meta, (PydanticMetadata, MaxLen)):
|
||||||
@ -444,7 +450,7 @@ else:
|
|||||||
)
|
)
|
||||||
return field.allow_none # type: ignore[no-any-return, attr-defined]
|
return field.allow_none # type: ignore[no-any-return, attr-defined]
|
||||||
|
|
||||||
def get_type_from_field(field: Any) -> Any:
|
def get_sa_type_from_field(field: Any) -> Any:
|
||||||
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
|
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
|
||||||
return field.type_
|
return field.type_
|
||||||
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
|
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
|
||||||
|
@ -71,7 +71,7 @@ from ._compat import ( # type: ignore[attr-defined]
|
|||||||
get_field_metadata,
|
get_field_metadata,
|
||||||
get_model_fields,
|
get_model_fields,
|
||||||
get_relationship_to,
|
get_relationship_to,
|
||||||
get_type_from_field,
|
get_sa_type_from_field,
|
||||||
init_pydantic_private_attrs,
|
init_pydantic_private_attrs,
|
||||||
is_field_noneable,
|
is_field_noneable,
|
||||||
is_table_model_class,
|
is_table_model_class,
|
||||||
@ -649,7 +649,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
|
|||||||
if sa_type is not Undefined:
|
if sa_type is not Undefined:
|
||||||
return sa_type
|
return sa_type
|
||||||
|
|
||||||
type_ = get_type_from_field(field)
|
type_ = get_sa_type_from_field(field)
|
||||||
metadata = get_field_metadata(field)
|
metadata = get_field_metadata(field)
|
||||||
|
|
||||||
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
|
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
|
||||||
|
26
tests/test_annotated_uuid.py
Normal file
26
tests/test_annotated_uuid.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||||
|
|
||||||
|
from tests.conftest import needs_pydanticv2
|
||||||
|
|
||||||
|
|
||||||
|
@needs_pydanticv2
|
||||||
|
def test_annotated_optional_types(clear_sqlmodel) -> None:
|
||||||
|
from pydantic import UUID4
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
# Pydantic UUID4 is: Annotated[UUID, UuidVersion(4)]
|
||||||
|
id: Optional[UUID4] = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||||
|
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
with Session(engine) as db:
|
||||||
|
hero = Hero()
|
||||||
|
db.add(hero)
|
||||||
|
db.commit()
|
||||||
|
statement = select(Hero)
|
||||||
|
result = db.exec(statement).all()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert isinstance(hero.id, uuid.UUID)
|
Loading…
x
Reference in New Issue
Block a user