mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-10-11 23:51:50 +08:00
Fix select: return typed comparisons if vectorized
This commit is contained in:
parent
027dc5bc8d
commit
dbd25f632b
@ -1048,22 +1048,33 @@ struct ternary_evaluator<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3>, IndexBased
|
|||||||
Data m_d;
|
Data m_d;
|
||||||
};
|
};
|
||||||
|
|
||||||
// specialization for expressions like (a < b).select(c, d) to enable full vectorization
|
|
||||||
template <typename Arg1, typename Arg2, typename Scalar, typename CmpLhsType, typename CmpRhsType, ComparisonName cmp>
|
template <typename Arg1, typename Arg2, typename Scalar, typename CmpLhsType, typename CmpRhsType, ComparisonName cmp>
|
||||||
struct evaluator<CwiseTernaryOp<scalar_boolean_select_op<Scalar, Scalar, bool>, Arg1, Arg2,
|
struct scalar_boolean_select_spec {
|
||||||
CwiseBinaryOp<scalar_cmp_op<Scalar, Scalar, cmp, false>, CmpLhsType, CmpRhsType>>>
|
|
||||||
: public ternary_evaluator<
|
|
||||||
CwiseTernaryOp<scalar_boolean_select_op<Scalar, Scalar, Scalar>, Arg1, Arg2,
|
|
||||||
CwiseBinaryOp<scalar_cmp_op<Scalar, Scalar, cmp, true>, CmpLhsType, CmpRhsType>>> {
|
|
||||||
using DummyTernaryOp = scalar_boolean_select_op<Scalar, Scalar, bool>;
|
using DummyTernaryOp = scalar_boolean_select_op<Scalar, Scalar, bool>;
|
||||||
using DummyArg3 = CwiseBinaryOp<scalar_cmp_op<Scalar, Scalar, cmp, false>, CmpLhsType, CmpRhsType>;
|
using DummyArg3 = CwiseBinaryOp<scalar_cmp_op<Scalar, Scalar, cmp, false>, CmpLhsType, CmpRhsType>;
|
||||||
using DummyXprType = CwiseTernaryOp<DummyTernaryOp, Arg1, Arg2, DummyArg3>;
|
using DummyXprType = CwiseTernaryOp<DummyTernaryOp, Arg1, Arg2, DummyArg3>;
|
||||||
|
|
||||||
using TernaryOp = scalar_boolean_select_op<Scalar, Scalar, Scalar>;
|
// only use the typed comparison if it is vectorized
|
||||||
using Arg3 = CwiseBinaryOp<scalar_cmp_op<Scalar, Scalar, cmp, true>, CmpLhsType, CmpRhsType>;
|
static constexpr bool UseTyped = functor_traits<scalar_cmp_op<Scalar, Scalar, cmp, true>>::PacketAccess;
|
||||||
|
using CondScalar = std::conditional_t<UseTyped, Scalar, bool>;
|
||||||
|
|
||||||
|
using TernaryOp = scalar_boolean_select_op<Scalar, Scalar, CondScalar>;
|
||||||
|
using Arg3 = CwiseBinaryOp<scalar_cmp_op<Scalar, Scalar, cmp, UseTyped>, CmpLhsType, CmpRhsType>;
|
||||||
using XprType = CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3>;
|
using XprType = CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3>;
|
||||||
|
|
||||||
using Base = ternary_evaluator<XprType>;
|
using Base = ternary_evaluator<XprType>;
|
||||||
|
};
|
||||||
|
|
||||||
|
// specialization for expressions like (a < b).select(c, d) to enable full vectorization
|
||||||
|
template <typename Arg1, typename Arg2, typename Scalar, typename CmpLhsType, typename CmpRhsType, ComparisonName cmp>
|
||||||
|
struct evaluator<CwiseTernaryOp<scalar_boolean_select_op<Scalar, Scalar, bool>, Arg1, Arg2,
|
||||||
|
CwiseBinaryOp<scalar_cmp_op<Scalar, Scalar, cmp, false>, CmpLhsType, CmpRhsType>>>
|
||||||
|
: public scalar_boolean_select_spec<Arg1, Arg2, Scalar, CmpLhsType, CmpRhsType, cmp>::Base {
|
||||||
|
using Helper = scalar_boolean_select_spec<Arg1, Arg2, Scalar, CmpLhsType, CmpRhsType, cmp>;
|
||||||
|
using Base = typename Helper::Base;
|
||||||
|
using DummyXprType = typename Helper::DummyXprType;
|
||||||
|
using Arg3 = typename Helper::Arg3;
|
||||||
|
using XprType = typename Helper::XprType;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC explicit evaluator(const DummyXprType& xpr)
|
EIGEN_DEVICE_FUNC explicit evaluator(const DummyXprType& xpr)
|
||||||
: Base(XprType(xpr.arg1(), xpr.arg2(), Arg3(xpr.arg3().lhs(), xpr.arg3().rhs()))) {}
|
: Base(XprType(xpr.arg1(), xpr.arg2(), Arg3(xpr.arg3().lhs(), xpr.arg3().rhs()))) {}
|
||||||
|
@ -207,20 +207,9 @@ struct functor_traits<scalar_cmp_op<LhsScalar, RhsScalar, cmp, UseTypedComparato
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
|
||||||
struct typed_cmp_helper {
|
|
||||||
static constexpr bool SameType = is_same<LhsScalar, RhsScalar>::value;
|
|
||||||
static constexpr bool IsNumeric = is_arithmetic<typename NumTraits<LhsScalar>::Real>::value;
|
|
||||||
static constexpr bool UseTyped = UseTypedComparators && SameType && IsNumeric;
|
|
||||||
using type = typename conditional<UseTyped, LhsScalar, bool>::type;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
|
||||||
using cmp_return_t = typename typed_cmp_helper<LhsScalar, RhsScalar, UseTypedComparators>::type;
|
|
||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
||||||
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_EQ, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_EQ, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
||||||
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
|
using result_type = std::conditional_t<UseTypedComparators, LhsScalar, bool>;
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
||||||
return a == b ? result_type(1) : result_type(0);
|
return a == b ? result_type(1) : result_type(0);
|
||||||
}
|
}
|
||||||
@ -233,7 +222,7 @@ struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_EQ, UseTypedComparators> : binary
|
|||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
||||||
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_LT, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_LT, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
||||||
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
|
using result_type = std::conditional_t<UseTypedComparators, LhsScalar, bool>;
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
||||||
return a < b ? result_type(1) : result_type(0);
|
return a < b ? result_type(1) : result_type(0);
|
||||||
}
|
}
|
||||||
@ -246,7 +235,7 @@ struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_LT, UseTypedComparators> : binary
|
|||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
||||||
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_LE, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_LE, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
||||||
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
|
using result_type = std::conditional_t<UseTypedComparators, LhsScalar, bool>;
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
||||||
return a <= b ? result_type(1) : result_type(0);
|
return a <= b ? result_type(1) : result_type(0);
|
||||||
}
|
}
|
||||||
@ -259,7 +248,7 @@ struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_LE, UseTypedComparators> : binary
|
|||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
||||||
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_GT, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_GT, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
||||||
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
|
using result_type = std::conditional_t<UseTypedComparators, LhsScalar, bool>;
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
||||||
return a > b ? result_type(1) : result_type(0);
|
return a > b ? result_type(1) : result_type(0);
|
||||||
}
|
}
|
||||||
@ -272,7 +261,7 @@ struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_GT, UseTypedComparators> : binary
|
|||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
||||||
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_GE, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_GE, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
||||||
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
|
using result_type = std::conditional_t<UseTypedComparators, LhsScalar, bool>;
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
||||||
return a >= b ? result_type(1) : result_type(0);
|
return a >= b ? result_type(1) : result_type(0);
|
||||||
}
|
}
|
||||||
@ -285,7 +274,7 @@ struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_GE, UseTypedComparators> : binary
|
|||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
||||||
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_UNORD, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_UNORD, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
||||||
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
|
using result_type = std::conditional_t<UseTypedComparators, LhsScalar, bool>;
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
||||||
return !(a <= b || b <= a) ? result_type(1) : result_type(0);
|
return !(a <= b || b <= a) ? result_type(1) : result_type(0);
|
||||||
}
|
}
|
||||||
@ -298,7 +287,7 @@ struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_UNORD, UseTypedComparators> : bin
|
|||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
|
||||||
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_NEQ, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_NEQ, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
|
||||||
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
|
using result_type = std::conditional_t<UseTypedComparators, LhsScalar, bool>;
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
|
||||||
return a != b ? result_type(1) : result_type(0);
|
return a != b ? result_type(1) : result_type(0);
|
||||||
}
|
}
|
||||||
|
@ -762,6 +762,18 @@ void comparisons(const ArrayType& m) {
|
|||||||
VERIFY_IS_APPROX(((m1.abs() + 1) > RealScalar(0.1)).colwise().count(),
|
VERIFY_IS_APPROX(((m1.abs() + 1) > RealScalar(0.1)).colwise().count(),
|
||||||
ArrayOfIndices::Constant(cols, rows).transpose());
|
ArrayOfIndices::Constant(cols, rows).transpose());
|
||||||
VERIFY_IS_APPROX(((m1.abs() + 1) > RealScalar(0.1)).rowwise().count(), ArrayOfIndices::Constant(rows, cols));
|
VERIFY_IS_APPROX(((m1.abs() + 1) > RealScalar(0.1)).rowwise().count(), ArrayOfIndices::Constant(rows, cols));
|
||||||
|
|
||||||
|
// simple data type that does not permit implicit conversions
|
||||||
|
struct scalar_wrapper {
|
||||||
|
Scalar m_data;
|
||||||
|
scalar_wrapper() : m_data(0) {}
|
||||||
|
explicit scalar_wrapper(Scalar data) : m_data(data) {}
|
||||||
|
bool operator==(scalar_wrapper other) const { return m_data == other.m_data; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// test bug2966: select did not support some scalar types that forbade implicit conversions from bool
|
||||||
|
ArrayX<scalar_wrapper> m5(10);
|
||||||
|
m5 = (m5 == scalar_wrapper(0)).select(m5, m5);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ArrayType>
|
template <typename ArrayType>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user