sqlmodel-fix/sqlmodel/sql/expression.py
Sebastián Ramírez d165e4b5ad
♻️ Refactor generate select template to isolate templated code to the minimum (#967)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-06-03 21:34:54 -05:00

216 lines
6.2 KiB
Python

from typing import (
Any,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
import sqlalchemy
from sqlalchemy import (
Column,
ColumnElement,
Extract,
FunctionElement,
FunctionFilter,
Label,
Over,
TypeCoerce,
WithinGroup,
)
from sqlalchemy.orm import InstrumentedAttribute, Mapped
from sqlalchemy.sql._typing import (
_ColumnExpressionArgument,
_ColumnExpressionOrLiteralArgument,
_ColumnExpressionOrStrLabelArgument,
)
from sqlalchemy.sql.elements import (
BinaryExpression,
Case,
Cast,
CollectionAggregate,
ColumnClause,
TryCast,
UnaryExpression,
)
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import Literal
from ._expression_select_cls import Select as Select
from ._expression_select_cls import SelectOfScalar as SelectOfScalar
from ._expression_select_gen import select as select
# Generic type variable used by the wrapper functions in this module to carry the
# Python-level type of a column / expression through to the SQLAlchemy result type.
_T = TypeVar("_T")
# Accepted form of a SQLAlchemy type argument (e.g. for cast/try_cast/type_coerce/label):
# either a TypeEngine subclass itself or an instance of one.
_TypeEngineArgument = Union[Type[TypeEngine[_T]], TypeEngine[_T]]
# Redefine operators that would only take a column expression to also take the (virtual)
# types of Pydantic models, e.g. str instead of only Mapped[str].
def all_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
    """Delegate to :func:`sqlalchemy.all_`, also accepting a plain ``_T``-typed ``expr``.

    ``expr`` is forwarded unchanged; only the statically accepted type is wider.
    """
    aggregate: CollectionAggregate[bool] = sqlalchemy.all_(expr)  # type: ignore[arg-type]
    return aggregate
def and_(
    initial_clause: Union[Literal[True], _ColumnExpressionArgument[bool], bool],
    *clauses: Union[_ColumnExpressionArgument[bool], bool],
) -> ColumnElement[bool]:
    """Delegate to :func:`sqlalchemy.and_` with ``initial_clause`` followed by ``clauses``.

    The annotations additionally admit plain ``bool`` values so Python-typed
    expressions type-check; the arguments themselves are passed through as-is.
    """
    conjunction: ColumnElement[bool] = sqlalchemy.and_(initial_clause, *clauses)  # type: ignore[arg-type]
    return conjunction
def any_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
    """Delegate to :func:`sqlalchemy.any_`, also accepting a plain ``_T``-typed ``expr``.

    ``expr`` is forwarded unchanged; only the statically accepted type is wider.
    """
    aggregate: CollectionAggregate[bool] = sqlalchemy.any_(expr)  # type: ignore[arg-type]
    return aggregate
def asc(
    column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
) -> UnaryExpression[_T]:
    """Delegate to :func:`sqlalchemy.asc`, also accepting a plain ``_T``-typed ``column``."""
    ordering: UnaryExpression[_T] = sqlalchemy.asc(column)  # type: ignore[arg-type]
    return ordering
def collate(
    expression: Union[_ColumnExpressionArgument[str], str], collation: str
) -> BinaryExpression[str]:
    """Delegate to :func:`sqlalchemy.collate` with ``expression`` and ``collation``.

    ``expression`` may also be typed as a plain ``str``; it is passed through as-is.
    """
    collated: BinaryExpression[str] = sqlalchemy.collate(expression, collation)  # type: ignore[arg-type]
    return collated
def between(
    expr: Union[_ColumnExpressionOrLiteralArgument[_T], _T],
    lower_bound: Any,
    upper_bound: Any,
    symmetric: bool = False,
) -> BinaryExpression[bool]:
    """Delegate to :func:`sqlalchemy.between`.

    ``expr`` may also be typed as a plain ``_T``; the bounds and the ``symmetric``
    flag are forwarded unchanged.
    """
    clause: BinaryExpression[bool] = sqlalchemy.between(
        expr, lower_bound, upper_bound, symmetric=symmetric
    )
    return clause
def not_(clause: Union[_ColumnExpressionArgument[_T], _T]) -> ColumnElement[_T]:
    """Delegate to :func:`sqlalchemy.not_`, also accepting a plain ``_T``-typed ``clause``."""
    negated: ColumnElement[_T] = sqlalchemy.not_(clause)  # type: ignore[arg-type]
    return negated
def case(
    *whens: Union[
        Tuple[Union[_ColumnExpressionArgument[bool], bool], Any], Mapping[Any, Any]
    ],
    value: Optional[Any] = None,
    else_: Optional[Any] = None,
) -> Case[Any]:
    """Delegate to :func:`sqlalchemy.case` with ``whens``, ``value`` and ``else_``.

    Each ``when`` is either a ``(condition, result)`` pair, whose condition may
    also be a plain ``bool`` in the annotation, or a mapping; all are forwarded
    unchanged.
    """
    expression: Case[Any] = sqlalchemy.case(*whens, value=value, else_=else_)  # type: ignore[arg-type]
    return expression
def cast(
    expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
    type_: "_TypeEngineArgument[_T]",
) -> Cast[_T]:
    """Delegate to :func:`sqlalchemy.cast`, converting ``expression`` to ``type_``.

    ``type_`` may be a TypeEngine class or instance.
    """
    converted: Cast[_T] = sqlalchemy.cast(expression, type_)
    return converted
def try_cast(
    expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
    type_: "_TypeEngineArgument[_T]",
) -> TryCast[_T]:
    """Delegate to :func:`sqlalchemy.try_cast` with ``expression`` and ``type_``.

    ``type_`` may be a TypeEngine class or instance.
    """
    converted: TryCast[_T] = sqlalchemy.try_cast(expression, type_)
    return converted
def desc(
    column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
) -> UnaryExpression[_T]:
    """Delegate to :func:`sqlalchemy.desc`, also accepting a plain ``_T``-typed ``column``."""
    ordering: UnaryExpression[_T] = sqlalchemy.desc(column)  # type: ignore[arg-type]
    return ordering
def distinct(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
    """Delegate to :func:`sqlalchemy.distinct`, also accepting a plain ``_T``-typed ``expr``."""
    unique: UnaryExpression[_T] = sqlalchemy.distinct(expr)  # type: ignore[arg-type]
    return unique
def bitwise_not(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
    """Delegate to :func:`sqlalchemy.bitwise_not`, also accepting a plain ``_T``-typed ``expr``."""
    inverted: UnaryExpression[_T] = sqlalchemy.bitwise_not(expr)  # type: ignore[arg-type]
    return inverted
def extract(field: str, expr: Union[_ColumnExpressionArgument[Any], Any]) -> Extract:
    """Delegate to :func:`sqlalchemy.extract`, pulling ``field`` out of ``expr``."""
    extracted: Extract = sqlalchemy.extract(field, expr)
    return extracted
def funcfilter(
    func: FunctionElement[_T], *criterion: Union[_ColumnExpressionArgument[bool], bool]
) -> FunctionFilter[_T]:
    """Delegate to :func:`sqlalchemy.funcfilter`, applying ``criterion`` to ``func``.

    Each criterion may also be typed as a plain ``bool``; all are forwarded as-is.
    """
    filtered: FunctionFilter[_T] = sqlalchemy.funcfilter(func, *criterion)  # type: ignore[arg-type]
    return filtered
def label(
    name: str,
    element: Union[_ColumnExpressionArgument[_T], _T],
    type_: Optional["_TypeEngineArgument[_T]"] = None,
) -> Label[_T]:
    """Delegate to :func:`sqlalchemy.label`, naming ``element`` as ``name``.

    ``element`` may also be typed as a plain ``_T``; ``type_`` (optional) is
    forwarded by keyword exactly as received.
    """
    labelled: Label[_T] = sqlalchemy.label(name, element, type_=type_)  # type: ignore[arg-type]
    return labelled
def nulls_first(
    column: Union[_ColumnExpressionArgument[_T], _T],
) -> UnaryExpression[_T]:
    """Delegate to :func:`sqlalchemy.nulls_first`, also accepting a plain ``_T``-typed ``column``."""
    ordering: UnaryExpression[_T] = sqlalchemy.nulls_first(column)  # type: ignore[arg-type]
    return ordering
def nulls_last(column: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
    """Delegate to :func:`sqlalchemy.nulls_last`, also accepting a plain ``_T``-typed ``column``."""
    ordering: UnaryExpression[_T] = sqlalchemy.nulls_last(column)  # type: ignore[arg-type]
    return ordering
def or_(
    initial_clause: Union[Literal[False], _ColumnExpressionArgument[bool], bool],
    *clauses: Union[_ColumnExpressionArgument[bool], bool],
) -> ColumnElement[bool]:
    """Delegate to :func:`sqlalchemy.or_` with ``initial_clause`` followed by ``clauses``.

    The annotations additionally admit plain ``bool`` values so Python-typed
    expressions type-check; the arguments themselves are passed through as-is.
    """
    disjunction: ColumnElement[bool] = sqlalchemy.or_(initial_clause, *clauses)  # type: ignore[arg-type]
    return disjunction
def over(
    element: FunctionElement[_T],
    partition_by: Optional[
        Union[
            Iterable[Union[_ColumnExpressionArgument[Any], Any]],
            _ColumnExpressionArgument[Any],
            Any,
        ]
    ] = None,
    order_by: Optional[
        Union[
            Iterable[Union[_ColumnExpressionArgument[Any], Any]],
            _ColumnExpressionArgument[Any],
            Any,
        ]
    ] = None,
    range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
    rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
) -> Over[_T]:
    """Delegate to :func:`sqlalchemy.over` for ``element``.

    ``partition_by``, ``order_by``, ``range_`` and ``rows`` are forwarded by
    keyword exactly as received (``None`` included).
    """
    window: Over[_T] = sqlalchemy.over(
        element,
        partition_by=partition_by,
        order_by=order_by,
        range_=range_,
        rows=rows,
    )
    return window
def tuple_(
    *clauses: Union[_ColumnExpressionArgument[Any], Any],
    types: Optional[Sequence["_TypeEngineArgument[Any]"]] = None,
) -> Tuple[Any, ...]:
    """Delegate to :func:`sqlalchemy.tuple_` with ``clauses`` and optional ``types``.

    NOTE(review): the ``Tuple[Any, ...]`` annotation does not match what
    ``sqlalchemy.tuple_`` declares (hence the ignore below). At runtime the value
    is whatever SQLAlchemy builds, presumably a SQL tuple clause rather than a
    Python tuple. Confirm before relying on Python tuple behaviour.
    """
    clause_tuple = sqlalchemy.tuple_(*clauses, types=types)
    return clause_tuple  # type: ignore[return-value]
def type_coerce(
    expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
    type_: "_TypeEngineArgument[_T]",
) -> TypeCoerce[_T]:
    """Delegate to :func:`sqlalchemy.type_coerce` with ``expression`` and ``type_``.

    ``type_`` may be a TypeEngine class or instance.
    """
    coerced: TypeCoerce[_T] = sqlalchemy.type_coerce(expression, type_)
    return coerced
def within_group(
    element: FunctionElement[_T], *order_by: Union[_ColumnExpressionArgument[Any], Any]
) -> WithinGroup[_T]:
    """Delegate to :func:`sqlalchemy.within_group`, ordering ``element`` by ``order_by``."""
    grouped: WithinGroup[_T] = sqlalchemy.within_group(element, *order_by)
    return grouped
def col(column_expression: _T) -> Mapped[_T]:
    """Return ``column_expression`` unchanged, re-typed as ``Mapped[_T]``.

    Raises:
        RuntimeError: if ``column_expression`` is not an instance of
            ``ColumnClause``, ``Column`` or ``InstrumentedAttribute``.
    """
    accepted_types = (ColumnClause, Column, InstrumentedAttribute)
    if isinstance(column_expression, accepted_types):
        # Only the static type changes; the very same object is handed back.
        return column_expression  # type: ignore
    raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")