Reimplement the tensor comparison operators using the scalar_cmp_op functors. This makes them more CUDA-friendly.

This commit is contained in:
Benoit Steiner 2015-11-06 09:18:43 -08:00
parent bfd6ee64f3
commit ed1962b464
3 changed files with 15 additions and 5 deletions

View File

@ -186,6 +186,14 @@ template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_LE> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a<=b;}
};
// Specialization of scalar_cmp_op selected by the cmp_GT tag ("greater than").
// operator() takes two scalars by const reference and returns the bool result of
// `a > b`, so it relies on Scalar providing operator>.
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_GT> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a>b;}
};
// Specialization of scalar_cmp_op selected by the cmp_GE tag ("greater than or equal").
// operator() takes two scalars by const reference and returns the bool result of
// `a >= b`, so it relies on Scalar providing operator>=.
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_GE> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a>=b;}
};
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_UNORD> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return !(a<=b || b<=a);}

View File

@ -531,7 +531,9 @@ enum ComparisonName {
cmp_LT = 1,
cmp_LE = 2,
cmp_UNORD = 3,
cmp_NEQ = 4
cmp_NEQ = 4,
cmp_GT = 5,
cmp_GE = 6
};
} // end namespace internal

View File

@ -257,12 +257,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
// operator<: returns binaryExpr(other.derived(), functor) for this expression and
// `other`, where the functor decides the "less than" result.
// The two consecutive return lines below appear to be the removed (std::less) and the
// added (internal::scalar_cmp_op<Scalar, internal::cmp_LT>) versions of the same
// statement as rendered by the commit diff; read as plain code, the second return
// would be unreachable.
// NOTE(review): the declared return type still names std::less<Scalar> as the functor
// while the new return line passes a scalar_cmp_op functor -- confirm the declaration
// was updated to match, otherwise the types disagree.
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>
operator<(const OtherDerived& other) const {
return binaryExpr(other.derived(), std::less<Scalar>());
return binaryExpr(other.derived(), internal::scalar_cmp_op<Scalar, internal::cmp_LT>());
}
// operator<=: returns binaryExpr(other.derived(), functor) for this expression and
// `other`, where the functor decides the "less than or equal" result.
// The two consecutive return lines below appear to be the removed (std::less_equal) and
// the added (internal::scalar_cmp_op<Scalar, internal::cmp_LE>) versions of the same
// statement as rendered by the commit diff; read as plain code, the second return
// would be unreachable.
// NOTE(review): the declared return type still names std::less_equal<Scalar> as the
// functor while the new return line passes a scalar_cmp_op functor -- confirm the
// declaration was updated to match, otherwise the types disagree.
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>
operator<=(const OtherDerived& other) const {
return binaryExpr(other.derived(), std::less_equal<Scalar>());
return binaryExpr(other.derived(), internal::scalar_cmp_op<Scalar, internal::cmp_LE>());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>
@ -278,12 +278,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
// operator==: returns binaryExpr(other.derived(), functor) for this expression and
// `other`, where the functor decides the equality result.
// The two consecutive return lines below appear to be the removed (std::equal_to) and
// the added (internal::scalar_cmp_op<Scalar, internal::cmp_EQ>) versions of the same
// statement as rendered by the commit diff; read as plain code, the second return
// would be unreachable.
// NOTE(review): the declared return type still names std::equal_to<Scalar> as the
// functor while the new return line passes a scalar_cmp_op functor -- confirm the
// declaration was updated to match, otherwise the types disagree.
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
operator==(const OtherDerived& other) const {
return binaryExpr(other.derived(), std::equal_to<Scalar>());
return binaryExpr(other.derived(), internal::scalar_cmp_op<Scalar, internal::cmp_EQ>());
}
// operator!=: returns binaryExpr(other.derived(), functor) for this expression and
// `other`, where the functor decides the inequality result.
// The two consecutive return lines below appear to be the removed (std::not_equal_to)
// and the added (internal::scalar_cmp_op<Scalar, internal::cmp_NEQ>) versions of the
// same statement as rendered by the commit diff; read as plain code, the second return
// would be unreachable.
// NOTE(review): the declared return type still names std::not_equal_to<Scalar> as the
// functor while the new return line passes a scalar_cmp_op functor -- confirm the
// declaration was updated to match, otherwise the types disagree.
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
operator!=(const OtherDerived& other) const {
return binaryExpr(other.derived(), std::not_equal_to<Scalar>());
return binaryExpr(other.derived(), internal::scalar_cmp_op<Scalar, internal::cmp_NEQ>());
}
// comparisons and tests for Scalars