mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-20 08:39:37 +08:00
Enable equality comparisons on GPU.
Since `std::equal_to::operator()` is not a device function, it fails on GPU. On my device, I seem to get a silent crash in the kernel (no error is reported, but the kernel does not complete). Replacing `std::equal_to` with a portable version enables comparisons on device. Addresses #2292 - would need to be cherry-picked. The 3.3 branch also requires adding `EIGEN_DEVICE_FUNC` in `BooleanRedux.h` to be fully working.
This commit is contained in:
parent
c86ac71b4f
commit
7880f10526
@ -12,6 +12,28 @@
|
|||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
|
// Portable replacements for certain functors.
|
||||||
|
namespace numext {
|
||||||
|
|
||||||
|
template<typename T = void>
|
||||||
|
struct equal_to {
|
||||||
|
typedef bool result_type;
|
||||||
|
EIGEN_DEVICE_FUNC bool operator()(const T& lhs, const T& rhs) const {
|
||||||
|
return lhs == rhs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T = void>
|
||||||
|
struct not_equal_to {
|
||||||
|
typedef bool result_type;
|
||||||
|
EIGEN_DEVICE_FUNC bool operator()(const T& lhs, const T& rhs) const {
|
||||||
|
return lhs != rhs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
// default functor traits for STL functors:
|
// default functor traits for STL functors:
|
||||||
@ -68,10 +90,18 @@ template<typename T>
|
|||||||
struct functor_traits<std::equal_to<T> >
|
struct functor_traits<std::equal_to<T> >
|
||||||
{ enum { Cost = 1, PacketAccess = false }; };
|
{ enum { Cost = 1, PacketAccess = false }; };
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct functor_traits<numext::equal_to<T> >
|
||||||
|
: functor_traits<std::equal_to<T> > {};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
struct functor_traits<std::not_equal_to<T> >
|
struct functor_traits<std::not_equal_to<T> >
|
||||||
{ enum { Cost = 1, PacketAccess = false }; };
|
{ enum { Cost = 1, PacketAccess = false }; };
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct functor_traits<numext::not_equal_to<T> >
|
||||||
|
: functor_traits<std::not_equal_to<T> > {};
|
||||||
|
|
||||||
#if (EIGEN_COMP_CXXVER < 11)
|
#if (EIGEN_COMP_CXXVER < 11)
|
||||||
// std::binder* are deprecated since c++11 and will be removed in c++17
|
// std::binder* are deprecated since c++11 and will be removed in c++17
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
@ -39,10 +39,10 @@ cwiseProduct(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
|
|||||||
*/
|
*/
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline const CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
|
inline const CwiseBinaryOp<numext::equal_to<Scalar>, const Derived, const OtherDerived>
|
||||||
cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
|
cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
|
||||||
{
|
{
|
||||||
return CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
return CwiseBinaryOp<numext::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \returns an expression of the coefficient-wise != operator of *this and \a other
|
/** \returns an expression of the coefficient-wise != operator of *this and \a other
|
||||||
@ -59,10 +59,10 @@ cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
|
|||||||
*/
|
*/
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline const CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
|
inline const CwiseBinaryOp<numext::not_equal_to<Scalar>, const Derived, const OtherDerived>
|
||||||
cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
|
cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
|
||||||
{
|
{
|
||||||
return CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
return CwiseBinaryOp<numext::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \returns an expression of the coefficient-wise min of *this and \a other
|
/** \returns an expression of the coefficient-wise min of *this and \a other
|
||||||
|
@ -197,18 +197,17 @@ struct complex_operators {
|
|||||||
res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() /= x2.array();
|
res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() /= x2.array();
|
||||||
block_idx += size;
|
block_idx += size;
|
||||||
|
|
||||||
// Equality comparisons currently not functional on device
|
const T true_vector = T::Constant(true_value);
|
||||||
// (std::equal_to<T> is host-only).
|
const T false_vector = T::Constant(false_value);
|
||||||
// const T true_vector = T::Constant(true_value);
|
res.segment(block_idx, size) = (x1 == x2 ? true_vector : false_vector);
|
||||||
// const T false_vector = T::Constant(false_value);
|
block_idx += size;
|
||||||
// res.segment(block_idx, size) = (x1 == x2 ? true_vector : false_vector);
|
// Mixing types in equality comparison does not work.
|
||||||
// block_idx += size;
|
|
||||||
// res.segment(block_idx, size) = (x1 == x2.real() ? true_vector : false_vector);
|
// res.segment(block_idx, size) = (x1 == x2.real() ? true_vector : false_vector);
|
||||||
// block_idx += size;
|
// block_idx += size;
|
||||||
// res.segment(block_idx, size) = (x1.real() == x2 ? true_vector : false_vector);
|
// res.segment(block_idx, size) = (x1.real() == x2 ? true_vector : false_vector);
|
||||||
// block_idx += size;
|
// block_idx += size;
|
||||||
// res.segment(block_idx, size) = (x1 != x2 ? true_vector : false_vector);
|
res.segment(block_idx, size) = (x1 != x2 ? true_vector : false_vector);
|
||||||
// block_idx += size;
|
block_idx += size;
|
||||||
// res.segment(block_idx, size) = (x1 != x2.real() ? true_vector : false_vector);
|
// res.segment(block_idx, size) = (x1 != x2.real() ? true_vector : false_vector);
|
||||||
// block_idx += size;
|
// block_idx += size;
|
||||||
// res.segment(block_idx, size) = (x1.real() != x2 ? true_vector : false_vector);
|
// res.segment(block_idx, size) = (x1.real() != x2 ? true_vector : false_vector);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user