diff --git a/Eigen/src/Core/functors/StlFunctors.h b/Eigen/src/Core/functors/StlFunctors.h index d2e7b5b03..4570c9b63 100644 --- a/Eigen/src/Core/functors/StlFunctors.h +++ b/Eigen/src/Core/functors/StlFunctors.h @@ -12,6 +12,28 @@ namespace Eigen { +// Portable replacements for certain functors. +namespace numext { + +template +struct equal_to { + typedef bool result_type; + EIGEN_DEVICE_FUNC bool operator()(const T& lhs, const T& rhs) const { + return lhs == rhs; + } +}; + +template +struct not_equal_to { + typedef bool result_type; + EIGEN_DEVICE_FUNC bool operator()(const T& lhs, const T& rhs) const { + return lhs != rhs; + } +}; + +} + + namespace internal { // default functor traits for STL functors: @@ -68,10 +90,18 @@ template struct functor_traits > { enum { Cost = 1, PacketAccess = false }; }; +template +struct functor_traits > + : functor_traits > {}; + template struct functor_traits > { enum { Cost = 1, PacketAccess = false }; }; +template +struct functor_traits > + : functor_traits > {}; + #if (EIGEN_COMP_CXXVER < 11) // std::binder* are deprecated since c++11 and will be removed in c++17 template diff --git a/Eigen/src/plugins/MatrixCwiseBinaryOps.h b/Eigen/src/plugins/MatrixCwiseBinaryOps.h index f1084abef..a0feef871 100644 --- a/Eigen/src/plugins/MatrixCwiseBinaryOps.h +++ b/Eigen/src/plugins/MatrixCwiseBinaryOps.h @@ -39,10 +39,10 @@ cwiseProduct(const EIGEN_CURRENT_STORAGE_BASE_CLASS &other) const */ template EIGEN_DEVICE_FUNC -inline const CwiseBinaryOp, const Derived, const OtherDerived> +inline const CwiseBinaryOp, const Derived, const OtherDerived> cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS &other) const { - return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } /** \returns an expression of the coefficient-wise != operator of *this and \a other @@ -59,10 +59,10 @@ cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS &other) const */ template EIGEN_DEVICE_FUNC -inline const CwiseBinaryOp, const Derived, const OtherDerived> +inline const CwiseBinaryOp, const Derived, const OtherDerived> cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS &other) const { - return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + return CwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); } /** \returns an expression of the coefficient-wise min of *this and \a other diff --git a/test/gpu_basic.cu b/test/gpu_basic.cu index bf8dcacde..4298da3bb 100644 --- a/test/gpu_basic.cu +++ b/test/gpu_basic.cu @@ -197,18 +197,17 @@ struct complex_operators { res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() /= x2.array(); block_idx += size; - // Equality comparisons currently not functional on device - // (std::equal_to is host-only). - // const T true_vector = T::Constant(true_value); - // const T false_vector = T::Constant(false_value); - // res.segment(block_idx, size) = (x1 == x2 ? true_vector : false_vector); - // block_idx += size; + const T true_vector = T::Constant(true_value); + const T false_vector = T::Constant(false_value); + res.segment(block_idx, size) = (x1 == x2 ? true_vector : false_vector); + block_idx += size; + // Mixing types in equality comparison does not work. // res.segment(block_idx, size) = (x1 == x2.real() ? true_vector : false_vector); // block_idx += size; // res.segment(block_idx, size) = (x1.real() == x2 ? true_vector : false_vector); // block_idx += size; - // res.segment(block_idx, size) = (x1 != x2 ? true_vector : false_vector); - // block_idx += size; + res.segment(block_idx, size) = (x1 != x2 ? true_vector : false_vector); + block_idx += size; // res.segment(block_idx, size) = (x1 != x2.real() ? true_vector : false_vector); // block_idx += size; // res.segment(block_idx, size) = (x1.real() != x2 ? true_vector : false_vector);