mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 04:35:57 +08:00
Reimplement the tensor comparison operators by using the scalar_cmp_op functors. This makes them more cuda friendly.
This commit is contained in:
parent
bfd6ee64f3
commit
ed1962b464
@ -186,6 +186,14 @@ template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_LE> {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a<=b;}
|
||||
};
|
||||
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_GT> {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a>b;}
|
||||
};
|
||||
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_GE> {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return a>=b;}
|
||||
};
|
||||
template<typename Scalar> struct scalar_cmp_op<Scalar, cmp_UNORD> {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a, const Scalar& b) const {return !(a<=b || b<=a);}
|
||||
|
@ -531,7 +531,9 @@ enum ComparisonName {
|
||||
cmp_LT = 1,
|
||||
cmp_LE = 2,
|
||||
cmp_UNORD = 3,
|
||||
cmp_NEQ = 4
|
||||
cmp_NEQ = 4,
|
||||
cmp_GT = 5,
|
||||
cmp_GE = 6
|
||||
};
|
||||
} // end namespace internal
|
||||
|
||||
|
@ -257,12 +257,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const OtherDerived>
|
||||
operator<(const OtherDerived& other) const {
|
||||
return binaryExpr(other.derived(), std::less<Scalar>());
|
||||
return binaryExpr(other.derived(), internal::scalar_cmp_op<Scalar, internal::cmp_LT>());
|
||||
}
|
||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const OtherDerived>
|
||||
operator<=(const OtherDerived& other) const {
|
||||
return binaryExpr(other.derived(), std::less_equal<Scalar>());
|
||||
return binaryExpr(other.derived(), internal::scalar_cmp_op<Scalar, internal::cmp_LE>());
|
||||
}
|
||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const OtherDerived>
|
||||
@ -278,12 +278,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
|
||||
operator==(const OtherDerived& other) const {
|
||||
return binaryExpr(other.derived(), std::equal_to<Scalar>());
|
||||
return binaryExpr(other.derived(), internal::scalar_cmp_op<Scalar, internal::cmp_EQ>());
|
||||
}
|
||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
|
||||
operator!=(const OtherDerived& other) const {
|
||||
return binaryExpr(other.derived(), std::not_equal_to<Scalar>());
|
||||
return binaryExpr(other.derived(), internal::scalar_cmp_op<Scalar, internal::cmp_NEQ>());
|
||||
}
|
||||
|
||||
// comparisons and tests for Scalars
|
||||
|
Loading…
x
Reference in New Issue
Block a user