Enable equality comparisons on GPU.

Since `std::equal_to::operator()` is not a device function, calling it
from device code fails on GPU.  On my device, I seem to get a silent crash
in the kernel (no reported error, but the kernel does not complete).

Replacing this with a portable version enables comparisons on device.

Addresses #2292 - this change would need to be cherry-picked.  The 3.3
branch also requires adding `EIGEN_DEVICE_FUNC` in `BooleanRedux.h` for
comparisons to work fully.
This commit is contained in:
Antonio Sanchez 2021-07-20 13:53:41 -07:00 committed by Rasmus Munk Larsen
parent c86ac71b4f
commit 7880f10526
3 changed files with 41 additions and 12 deletions

View File

@ -12,6 +12,28 @@
namespace Eigen {
// Portable replacements for certain functors.
namespace numext {
/**
 * Portable replacement for std::equal_to: a binary predicate whose call
 * operator is annotated with EIGEN_DEVICE_FUNC.
 *
 * \tparam T type of both operands; they are compared with operator==.
 */
template<typename T = void>
struct equal_to {
  typedef bool result_type;
  /** \returns the result of `lhs == rhs`. */
  EIGEN_DEVICE_FUNC bool operator()(const T& lhs, const T& rhs) const {
    return lhs == rhs;
  }
};

/**
 * Specialization selected by the default template argument (`equal_to<>`).
 * The primary template cannot be instantiated with T = void, since it would
 * declare `const void&` parameters; here the operand types are deduced at
 * each call instead, in the spirit of the transparent std::equal_to<void>.
 */
template<>
struct equal_to<void> {
  typedef bool result_type;
  /** \returns the result of `lhs == rhs`, converted to bool. */
  template<typename Lhs, typename Rhs>
  EIGEN_DEVICE_FUNC bool operator()(const Lhs& lhs, const Rhs& rhs) const {
    return lhs == rhs;
  }
};
/**
 * Portable replacement for std::not_equal_to: a binary predicate whose call
 * operator is annotated with EIGEN_DEVICE_FUNC.
 *
 * \tparam T type of both operands; they are compared with operator!=.
 */
template<typename T = void>
struct not_equal_to {
  typedef bool result_type;
  /** \returns the result of `lhs != rhs`. */
  EIGEN_DEVICE_FUNC bool operator()(const T& lhs, const T& rhs) const {
    return lhs != rhs;
  }
};

/**
 * Specialization selected by the default template argument
 * (`not_equal_to<>`). The primary template cannot be instantiated with
 * T = void, since it would declare `const void&` parameters; here the operand
 * types are deduced at each call instead, in the spirit of the transparent
 * std::not_equal_to<void>.
 */
template<>
struct not_equal_to<void> {
  typedef bool result_type;
  /** \returns the result of `lhs != rhs`, converted to bool. */
  template<typename Lhs, typename Rhs>
  EIGEN_DEVICE_FUNC bool operator()(const Lhs& lhs, const Rhs& rhs) const {
    return lhs != rhs;
  }
};
}
namespace internal {
// default functor traits for STL functors:
@ -68,10 +90,18 @@ template<typename T>
struct functor_traits<std::equal_to<T> >
{ enum { Cost = 1, PacketAccess = false }; };
// numext::equal_to<T> reuses the traits declared for std::equal_to<T>.
template <typename T>
struct functor_traits<numext::equal_to<T> > : functor_traits<std::equal_to<T> >
{
};
// Traits for std::not_equal_to<T>: Cost is 1 and PacketAccess is false.
template <typename T>
struct functor_traits<std::not_equal_to<T> >
{
  enum {
    Cost = 1,
    PacketAccess = false
  };
};
// numext::not_equal_to<T> reuses the traits declared for std::not_equal_to<T>.
template <typename T>
struct functor_traits<numext::not_equal_to<T> > : functor_traits<std::not_equal_to<T> >
{
};
#if (EIGEN_COMP_CXXVER < 11)
// std::binder* are deprecated since c++11 and will be removed in c++17
template<typename T>

View File

@ -39,10 +39,10 @@ cwiseProduct(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
inline const CwiseBinaryOp<numext::equal_to<Scalar>, const Derived, const OtherDerived>
cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryOp<numext::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise != operator of *this and \a other
@ -59,10 +59,10 @@ cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
inline const CwiseBinaryOp<numext::not_equal_to<Scalar>, const Derived, const OtherDerived>
cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryOp<numext::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise min of *this and \a other

View File

@ -197,18 +197,17 @@ struct complex_operators {
res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() /= x2.array();
block_idx += size;
// Equality comparisons currently not functional on device
// (std::equal_to<T> is host-only).
// const T true_vector = T::Constant(true_value);
// const T false_vector = T::Constant(false_value);
// res.segment(block_idx, size) = (x1 == x2 ? true_vector : false_vector);
// block_idx += size;
const T true_vector = T::Constant(true_value);
const T false_vector = T::Constant(false_value);
res.segment(block_idx, size) = (x1 == x2 ? true_vector : false_vector);
block_idx += size;
// Mixing types in equality comparison does not work.
// res.segment(block_idx, size) = (x1 == x2.real() ? true_vector : false_vector);
// block_idx += size;
// res.segment(block_idx, size) = (x1.real() == x2 ? true_vector : false_vector);
// block_idx += size;
// res.segment(block_idx, size) = (x1 != x2 ? true_vector : false_vector);
// block_idx += size;
res.segment(block_idx, size) = (x1 != x2 ? true_vector : false_vector);
block_idx += size;
// res.segment(block_idx, size) = (x1 != x2.real() ? true_vector : false_vector);
// block_idx += size;
// res.segment(block_idx, size) = (x1.real() != x2 ? true_vector : false_vector);