From ed1962b464119395d23fa7deeb9b7e77b0c05d18 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 6 Nov 2015 09:18:43 -0800 Subject: [PATCH] Reimplement the tensor comparison operators by using the scalar_cmp_op functors. This makes them more cuda friendly. --- Eigen/src/Core/functors/BinaryFunctors.h | 8 ++++++++ Eigen/src/Core/util/Constants.h | 4 +++- unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 8 ++++---- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index cc0e80a33..f77066910 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -186,6 +186,14 @@ template struct scalar_cmp_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a<=b;} }; +template struct scalar_cmp_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a>b;} +}; +template struct scalar_cmp_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a>=b;} +}; template struct scalar_cmp_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return !(a<=b || b<=a);} diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index c35077af6..28852c8c3 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h @@ -531,7 +531,9 @@ enum ComparisonName { cmp_LT = 1, cmp_LE = 2, cmp_UNORD = 3, - cmp_NEQ = 4 + cmp_NEQ = 4, + cmp_GT = 5, + cmp_GE = 6 }; } // end namespace internal diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index ceced984b..906687436 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -257,12 +257,12 @@ class TensorBase template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator<(const OtherDerived& other) const { - return binaryExpr(other.derived(), std::less()); + return binaryExpr(other.derived(), internal::scalar_cmp_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator<=(const OtherDerived& other) const { - return binaryExpr(other.derived(), std::less_equal()); + return binaryExpr(other.derived(), internal::scalar_cmp_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> @@ -278,12 +278,12 @@ class TensorBase template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator==(const OtherDerived& other) const { - return binaryExpr(other.derived(), std::equal_to()); + return binaryExpr(other.derived(), internal::scalar_cmp_op()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseBinaryOp, const Derived, const OtherDerived> operator!=(const OtherDerived& other) const { - return binaryExpr(other.derived(), std::not_equal_to()); + return binaryExpr(other.derived(), internal::scalar_cmp_op()); } // comparisons and tests for Scalars