diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index c9b2d2d28..60857e2cc 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -1048,22 +1048,33 @@ struct ternary_evaluator, IndexBased Data m_d; }; -// specialization for expressions like (a < b).select(c, d) to enable full vectorization template -struct evaluator, Arg1, Arg2, - CwiseBinaryOp, CmpLhsType, CmpRhsType>>> - : public ternary_evaluator< - CwiseTernaryOp, Arg1, Arg2, - CwiseBinaryOp, CmpLhsType, CmpRhsType>>> { +struct scalar_boolean_select_spec { using DummyTernaryOp = scalar_boolean_select_op; using DummyArg3 = CwiseBinaryOp, CmpLhsType, CmpRhsType>; using DummyXprType = CwiseTernaryOp; - using TernaryOp = scalar_boolean_select_op; - using Arg3 = CwiseBinaryOp, CmpLhsType, CmpRhsType>; + // only use the typed comparison if it is vectorized + static constexpr bool UseTyped = functor_traits>::PacketAccess; + using CondScalar = std::conditional_t; + + using TernaryOp = scalar_boolean_select_op; + using Arg3 = CwiseBinaryOp, CmpLhsType, CmpRhsType>; using XprType = CwiseTernaryOp; using Base = ternary_evaluator; +}; + +// specialization for expressions like (a < b).select(c, d) to enable full vectorization +template +struct evaluator, Arg1, Arg2, + CwiseBinaryOp, CmpLhsType, CmpRhsType>>> + : public scalar_boolean_select_spec::Base { + using Helper = scalar_boolean_select_spec; + using Base = typename Helper::Base; + using DummyXprType = typename Helper::DummyXprType; + using Arg3 = typename Helper::Arg3; + using XprType = typename Helper::XprType; EIGEN_DEVICE_FUNC explicit evaluator(const DummyXprType& xpr) : Base(XprType(xpr.arg1(), xpr.arg2(), Arg3(xpr.arg3().lhs(), xpr.arg3().rhs()))) {} diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index b6ecfb5d5..85e1584ea 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -207,20 +207,9 @@ struct functor_traits -struct typed_cmp_helper { - static constexpr bool SameType = is_same::value; - static constexpr bool IsNumeric = is_arithmetic::Real>::value; - static constexpr bool UseTyped = UseTypedComparators && SameType && IsNumeric; - using type = typename conditional::type; -}; - -template -using cmp_return_t = typename typed_cmp_helper::type; - template struct scalar_cmp_op : binary_op_base { - using result_type = cmp_return_t; + using result_type = std::conditional_t; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const { return a == b ? result_type(1) : result_type(0); } @@ -233,7 +222,7 @@ struct scalar_cmp_op : binary template struct scalar_cmp_op : binary_op_base { - using result_type = cmp_return_t; + using result_type = std::conditional_t; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const { return a < b ? result_type(1) : result_type(0); } @@ -246,7 +235,7 @@ struct scalar_cmp_op : binary template struct scalar_cmp_op : binary_op_base { - using result_type = cmp_return_t; + using result_type = std::conditional_t; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const { return a <= b ? result_type(1) : result_type(0); } @@ -259,7 +248,7 @@ struct scalar_cmp_op : binary template struct scalar_cmp_op : binary_op_base { - using result_type = cmp_return_t; + using result_type = std::conditional_t; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const { return a > b ? result_type(1) : result_type(0); } @@ -272,7 +261,7 @@ struct scalar_cmp_op : binary template struct scalar_cmp_op : binary_op_base { - using result_type = cmp_return_t; + using result_type = std::conditional_t; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const { return a >= b ? result_type(1) : result_type(0); } @@ -285,7 +274,7 @@ struct scalar_cmp_op : binary template struct scalar_cmp_op : binary_op_base { - using result_type = cmp_return_t; + using result_type = std::conditional_t; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const { return !(a <= b || b <= a) ? result_type(1) : result_type(0); } @@ -298,7 +287,7 @@ struct scalar_cmp_op : bin template struct scalar_cmp_op : binary_op_base { - using result_type = cmp_return_t; + using result_type = std::conditional_t; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const { return a != b ? result_type(1) : result_type(0); } diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index d37e47a64..f361cce77 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -762,6 +762,18 @@ void comparisons(const ArrayType& m) { VERIFY_IS_APPROX(((m1.abs() + 1) > RealScalar(0.1)).colwise().count(), ArrayOfIndices::Constant(cols, rows).transpose()); VERIFY_IS_APPROX(((m1.abs() + 1) > RealScalar(0.1)).rowwise().count(), ArrayOfIndices::Constant(rows, cols)); + + // simple data type that does not permit implicit conversions + struct scalar_wrapper { + Scalar m_data; + scalar_wrapper() : m_data(0) {} + explicit scalar_wrapper(Scalar data) : m_data(data) {} + bool operator==(scalar_wrapper other) const { return m_data == other.m_data; } + }; + + // test bug2966: select did not support some scalar types that forbade implicit conversions from bool + ArrayX m5(10); + m5 = (m5 == scalar_wrapper(0)).select(m5, m5); } template