diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index faf41f52f..cd8ae9ee5 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -195,7 +195,7 @@ struct functor_traits > { * \todo Implement packet-comparisons */ template + bool UseTypedComparators = false> struct scalar_cmp_op; template diff --git a/Eigen/src/plugins/MatrixCwiseBinaryOps.h b/Eigen/src/plugins/MatrixCwiseBinaryOps.h index 53dac96de..ad74c3568 100644 --- a/Eigen/src/plugins/MatrixCwiseBinaryOps.h +++ b/Eigen/src/plugins/MatrixCwiseBinaryOps.h @@ -25,6 +25,13 @@ cwiseProduct(const EIGEN_CURRENT_STORAGE_BASE_CLASS &other) const return EIGEN_CWISE_BINARY_RETURN_TYPE(Derived,OtherDerived,product)(derived(), other.derived()); } +template using CwiseBinaryEqualReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryNotEqualReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryLesserReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryGreaterReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryLesserOrEqualReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryGreaterOrEqualReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; + /** \returns an expression of the coefficient-wise == operator of *this and \a other * * \warning this performs an exact comparison, which is generally a bad idea with floating-point types. @@ -39,10 +46,10 @@ cwiseProduct(const EIGEN_CURRENT_STORAGE_BASE_CLASS &other) const */ template EIGEN_DEVICE_FUNC -inline const CwiseBinaryOp, const Derived, const OtherDerived> +inline const CwiseBinaryEqualReturnType cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS &other) const { - return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + return CwiseBinaryEqualReturnType(derived(), other.derived()); } /** \returns an expression of the coefficient-wise != operator of *this and \a other @@ -59,46 +66,46 @@ cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS &other) const */ template EIGEN_DEVICE_FUNC -inline const CwiseBinaryOp, const Derived, const OtherDerived> +inline const CwiseBinaryNotEqualReturnType cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS &other) const { - return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + return CwiseBinaryNotEqualReturnType(derived(), other.derived()); } /** \returns an expression of the coefficient-wise < operator of *this and \a other */ template EIGEN_DEVICE_FUNC -inline const CwiseBinaryOp, const Derived, const OtherDerived> +inline const CwiseBinaryLesserReturnType cwiseLesser(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { - return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + return CwiseBinaryLesserReturnType(derived(), other.derived()); } /** \returns an expression of the coefficient-wise > operator of *this and \a other */ template EIGEN_DEVICE_FUNC -inline const CwiseBinaryOp, const Derived, const OtherDerived> +inline const CwiseBinaryGreaterReturnType cwiseGreater(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { - return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + return CwiseBinaryGreaterReturnType(derived(), other.derived()); } /** \returns an expression of the coefficient-wise <= operator of *this and \a other */ template EIGEN_DEVICE_FUNC -inline const CwiseBinaryOp, const Derived, const OtherDerived> +inline const CwiseBinaryLesserOrEqualReturnType cwiseLesserOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { - return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + return CwiseBinaryLesserOrEqualReturnType(derived(), other.derived()); } /** \returns an expression of the coefficient-wise >= operator of *this and \a other */ template EIGEN_DEVICE_FUNC -inline const CwiseBinaryOp, const Derived, const OtherDerived> +inline const CwiseBinaryGreaterOrEqualReturnType cwiseGreaterOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { - return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + return CwiseBinaryGreaterOrEqualReturnType(derived(), other.derived()); } /** \returns an expression of the coefficient-wise min of *this and \a other @@ -191,7 +198,7 @@ EIGEN_DEVICE_FUNC inline const CwiseScalarEqualReturnType cwiseEqual(const Scalar& s) const { - return CwiseScalarEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op()); + return CwiseScalarEqualReturnType(derived(), Derived::Constant(rows(), cols(), s)); } @@ -208,7 +215,7 @@ EIGEN_DEVICE_FUNC inline const CwiseScalarNotEqualReturnType cwiseNotEqual(const Scalar& s) const { - return CwiseScalarNotEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op()); + return CwiseScalarNotEqualReturnType(derived(), Derived::Constant(rows(), cols(), s)); } /** \returns an expression of the coefficient-wise < operator of \c *this and a scalar \a s */ @@ -216,7 +223,7 @@ EIGEN_DEVICE_FUNC inline const CwiseScalarLesserReturnType cwiseLesser(const Scalar& s) const { - return CwiseScalarLesserReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op()); + return CwiseScalarLesserReturnType(derived(), Derived::Constant(rows(), cols(), s)); } /** \returns an expression of the coefficient-wise > operator of \c *this and a scalar \a s */ @@ -224,7 +231,7 @@ EIGEN_DEVICE_FUNC inline const CwiseScalarGreaterReturnType cwiseGreater(const Scalar& s) const { - return CwiseScalarGreaterReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op()); + return CwiseScalarGreaterReturnType(derived(), Derived::Constant(rows(), cols(), s)); } /** \returns an expression of the coefficient-wise <= operator of \c *this and a scalar \a s */ @@ -232,7 +239,7 @@ EIGEN_DEVICE_FUNC inline const CwiseScalarLesserOrEqualReturnType cwiseLesserOrEqual(const Scalar& s) const { - return CwiseScalarLesserOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op()); + return CwiseScalarLesserOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s)); } /** \returns an expression of the coefficient-wise >= operator of \c *this and a scalar \a s */ @@ -240,7 +247,62 @@ EIGEN_DEVICE_FUNC inline const CwiseScalarGreaterOrEqualReturnType cwiseGreaterOrEqual(const Scalar& s) const { - return CwiseScalarGreaterOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op()); + return CwiseScalarGreaterOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s)); } +template using CwiseBinaryTypedEqualReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryTypedNotEqualReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryTypedLesserReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryTypedGreaterReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryTypedLesserOrEqualReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; +template using CwiseBinaryTypedGreaterOrEqualReturnType = CwiseBinaryOp, const Derived, const OtherDerived>; + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedEqualReturnType +cwiseTypedEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { return CwiseBinaryTypedEqualReturnType(derived(), other.derived()); } + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedNotEqualReturnType +cwiseTypedNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { return CwiseBinaryTypedNotEqualReturnType(derived(), other.derived()); } + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedLesserReturnType +cwiseTypedLesser(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { return CwiseBinaryTypedLesserReturnType(derived(), other.derived()); } + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedGreaterReturnType +cwiseTypedGreater(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { return CwiseBinaryTypedGreaterReturnType(derived(), other.derived()); } + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedLesserOrEqualReturnType +cwiseTypedLesserOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { return CwiseBinaryTypedLesserOrEqualReturnType(derived(), other.derived()); } + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedGreaterOrEqualReturnType +cwiseTypedGreaterOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS& other) const { return CwiseBinaryTypedGreaterOrEqualReturnType(derived(), other.derived()); } + +using CwiseScalarTypedEqualReturnType = CwiseBinaryOp, const Derived, const ConstantReturnType>; +using CwiseScalarTypedNotEqualReturnType = CwiseBinaryOp, const Derived, const ConstantReturnType>; +using CwiseScalarTypedLesserReturnType = CwiseBinaryOp, const Derived, const ConstantReturnType>; +using CwiseScalarTypedGreaterReturnType = CwiseBinaryOp, const Derived, const ConstantReturnType>; +using CwiseScalarTypedLesserOrEqualReturnType = CwiseBinaryOp, const Derived, const ConstantReturnType>; +using CwiseScalarTypedGreaterOrEqualReturnType = CwiseBinaryOp, const Derived, const ConstantReturnType>; + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedEqualReturnType +cwiseTypedEqual(const Scalar& s) const { return CwiseScalarTypedEqualReturnType(derived(), ConstantReturnType(rows(), cols(), s)); } + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedNotEqualReturnType +cwiseTypedNotEqual(const Scalar& s) const { return CwiseScalarTypedNotEqualReturnType(derived(), ConstantReturnType(rows(), cols(), s)); } + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedLesserReturnType +cwiseTypedLesser(const Scalar& s) const { return CwiseScalarTypedLesserReturnType(derived(), ConstantReturnType(rows(), cols(), s)); } + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedGreaterReturnType +cwiseTypedGreater(const Scalar& s) const { return CwiseScalarTypedGreaterReturnType(derived(), ConstantReturnType(rows(), cols(), s)); } + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedLesserOrEqualReturnType +cwiseTypedLesserOrEqual(const Scalar& s) const { return CwiseScalarTypedLesserOrEqualReturnType(derived(), ConstantReturnType(rows(), cols(), s)); } + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedGreaterOrEqualReturnType +cwiseTypedGreaterOrEqual(const Scalar& s) const { return CwiseScalarTypedGreaterOrEqualReturnType(derived(), ConstantReturnType(rows(), cols(), s)); } diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 5d23d0395..9b8891dbe 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -590,21 +590,6 @@ template void comparisons(const ArrayType& m) typedef typename ArrayType::Scalar Scalar; typedef typename NumTraits::Real RealScalar; - // explicitly test both typed and boolean comparison ops - using typed_eq = internal::scalar_cmp_op; - using typed_ne = internal::scalar_cmp_op; - using typed_lt = internal::scalar_cmp_op; - using typed_le = internal::scalar_cmp_op; - using typed_gt = internal::scalar_cmp_op; - using typed_ge = internal::scalar_cmp_op; - - using bool_eq = internal::scalar_cmp_op; - using bool_ne = internal::scalar_cmp_op; - using bool_lt = internal::scalar_cmp_op; - using bool_le = internal::scalar_cmp_op; - using bool_gt = internal::scalar_cmp_op; - using bool_ge = internal::scalar_cmp_op; - Index rows = m.rows(); Index cols = m.cols(); @@ -649,28 +634,44 @@ template void comparisons(const ArrayType& m) // use typed comparisons, regardless of operator overload behavior typename ArrayType::ConstantReturnType typed_true = ArrayType::Constant(rows, cols, Scalar(1)); // (m1 + Scalar(1)) > m1).all() - VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, typed_gt()), typed_true); + VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseTypedGreater(m1), typed_true); // (m1 - Scalar(1)) < m1).all() - VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_lt()), typed_true); + VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseTypedLesser(m1), typed_true); // (m1 + Scalar(1)) == (m1 + Scalar(1))).all() - VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), typed_eq()), typed_true); + VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseTypedEqual(m1 + Scalar(1)), typed_true); // (m1 - Scalar(1)) != m1).all() - VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_ne()), typed_true); + VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseTypedNotEqual(m1), typed_true); // (m1 <= m2 || m1 >= m2).all() - VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, typed_le()) || m1.binaryExpr(m2, typed_ge()), typed_true); + VERIFY_IS_CWISE_EQUAL(m1.cwiseTypedGreaterOrEqual(m2) || m1.cwiseTypedLesserOrEqual(m2), typed_true); // use boolean comparisons, regardless of operator overload behavior ArrayXX::ConstantReturnType bool_true = ArrayXX::Constant(rows, cols, true); // (m1 + Scalar(1)) > m1).all() - VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, bool_gt()), bool_true); + VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseGreater(m1), bool_true); // (m1 - Scalar(1)) < m1).all() - VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_lt()), bool_true); + VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseLesser(m1), bool_true); // (m1 + Scalar(1)) == (m1 + Scalar(1))).all() - VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), bool_eq()), bool_true); + VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseEqual(m1 + Scalar(1)), bool_true); // (m1 - Scalar(1)) != m1).all() - VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_ne()), bool_true); + VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseNotEqual(m1), bool_true); // (m1 <= m2 || m1 >= m2).all() - VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, bool_le()) || m1.binaryExpr(m2, bool_ge()), bool_true); + VERIFY_IS_CWISE_EQUAL(m1.cwiseLesserOrEqual(m2) || m1.cwiseGreaterOrEqual(m2), bool_true); + + // test typed comparisons with scalar argument + VERIFY_IS_CWISE_EQUAL((m1 - m1).cwiseTypedEqual(Scalar(0)), typed_true); + VERIFY_IS_CWISE_EQUAL((m1.abs() + Scalar(1)).cwiseTypedNotEqual(Scalar(0)), typed_true); + VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseTypedGreater(m1.minCoeff()), typed_true); + VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseTypedLesser(m1.maxCoeff()), typed_true); + VERIFY_IS_CWISE_EQUAL(m1.abs().cwiseTypedLesserOrEqual(NumTraits::highest()), typed_true); + VERIFY_IS_CWISE_EQUAL((m1 * m1).cwiseTypedGreaterOrEqual(Scalar(0)), typed_true); + + // test boolean comparisons with scalar argument + VERIFY_IS_CWISE_EQUAL((m1 - m1).cwiseEqual(Scalar(0)), bool_true); + VERIFY_IS_CWISE_EQUAL((m1.abs() + Scalar(1)).cwiseNotEqual(Scalar(0)), bool_true); + VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseGreater(m1.minCoeff()), bool_true); + VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseLesser(m1.maxCoeff()), bool_true); + VERIFY_IS_CWISE_EQUAL(m1.abs().cwiseLesserOrEqual(NumTraits::highest()), bool_true); + VERIFY_IS_CWISE_EQUAL((m1 * m1).cwiseGreaterOrEqual(Scalar(0)), bool_true); // test Select VERIFY_IS_APPROX( (m1