Vectorize isfinite and isinf.

This commit is contained in:
Rasmus Munk Larsen 2024-05-29 00:20:12 +00:00 committed by Charles Schlosser
parent 5a9f66fb35
commit 9148c47d67
4 changed files with 110 additions and 21 deletions

View File

@ -579,12 +579,6 @@ EIGEN_DEVICE_FUNC inline Packet pandnot(const Packet& a, const Packet& b) {
return pand(a, pnot(b));
}
/** \internal \returns isnan(a) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pisnan(const Packet& a) {
return pandnot(ptrue(a), pcmp_eq(a, a));
}
// In the general case, use bitwise select.
template <typename Packet, typename EnableIf = void>
struct pselect_impl {
@ -1002,6 +996,20 @@ EIGEN_DEVICE_FUNC inline Packet pcplxflip(const Packet& a) {
* Special math functions
***************************/
/** \internal \returns isnan(a) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pisnan(const Packet& a) {
return pandnot(ptrue(a), pcmp_eq(a, a));
}
/** \internal \returns isinf(a) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pisinf(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
constexpr Scalar inf = NumTraits<Scalar>::infinity();
return pcmp_eq(pabs(a), pset1<Packet>(inf));
}
/** \internal \returns the sine of \a a (coeff-wise) */
template <typename Packet>
EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin(const Packet& a) {

View File

@ -989,10 +989,9 @@ struct functor_traits<scalar_isnan_op<Scalar, UseTypedPredicate>> {
* \brief Template functor to check whether a scalar is +/-inf
* \sa class CwiseUnaryOp, ArrayBase::isinf()
*/
template <typename Scalar>
template <typename Scalar, bool UseTypedPredicate = false>
struct scalar_isinf_op {
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const Scalar& a) const {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const {
#if defined(SYCL_DEVICE_ONLY)
return numext::isinf(a);
#else
@ -1000,19 +999,33 @@ struct scalar_isinf_op {
#endif
}
};
template <typename Scalar>
struct functor_traits<scalar_isinf_op<Scalar>> {
enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = false };
struct scalar_isinf_op<Scalar, true> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const {
#if defined(SYCL_DEVICE_ONLY)
return (numext::isinf(a) ? ptrue(a) : pzero(a));
#else
return (numext::isinf EIGEN_NOT_A_MACRO(a) ? ptrue(a) : pzero(a));
#endif
}
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
return pisinf(a);
}
};
template <typename Scalar, bool UseTypedPredicate>
struct functor_traits<scalar_isinf_op<Scalar, UseTypedPredicate>> {
enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasCmp && UseTypedPredicate };
};
/** \internal
* \brief Template functor to check whether a scalar has a finite value
* \sa class CwiseUnaryOp, ArrayBase::isfinite()
*/
template <typename Scalar>
template <typename Scalar, bool UseTypedPredicate = false>
struct scalar_isfinite_op {
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const Scalar& a) const {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const {
#if defined(SYCL_DEVICE_ONLY)
return numext::isfinite(a);
#else
@ -1020,9 +1033,25 @@ struct scalar_isfinite_op {
#endif
}
};
template <typename Scalar>
struct functor_traits<scalar_isfinite_op<Scalar>> {
enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = false };
struct scalar_isfinite_op<Scalar, true> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const {
#if defined(SYCL_DEVICE_ONLY)
return (numext::isfinite(a) ? ptrue(a) : pzero(a));
#else
return (numext::isfinite EIGEN_NOT_A_MACRO(a) ? ptrue(a) : pzero(a));
#endif
}
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
constexpr Scalar inf = NumTraits<Scalar>::infinity();
return pcmp_lt(pabs(a), pset1<Packet>(inf));
}
};
template <typename Scalar, bool UseTypedPredicate>
struct functor_traits<scalar_isfinite_op<Scalar, UseTypedPredicate>> {
enum { Cost = NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasCmp && UseTypedPredicate };
};
/** \internal

View File

@ -618,16 +618,15 @@ class TensorBase<Derived, ReadOnlyAccessors>
(isnan)() const {
return unaryExpr(internal::scalar_isnan_op<Scalar, true>()).template cast<bool>();
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isinf_op<Scalar>, const Derived>
EIGEN_STRONG_INLINE const TensorConversionOp<bool, const TensorCwiseUnaryOp<internal::scalar_isinf_op<Scalar, true>, const Derived>>
(isinf)() const {
return unaryExpr(internal::scalar_isinf_op<Scalar>());
return unaryExpr(internal::scalar_isinf_op<Scalar, true>()).template cast<bool>();
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isfinite_op<Scalar>, const Derived>
EIGEN_STRONG_INLINE const TensorConversionOp<bool, const TensorCwiseUnaryOp<internal::scalar_isfinite_op<Scalar, true>, const Derived>>
(isfinite)() const {
return unaryExpr(internal::scalar_isfinite_op<Scalar>());
return unaryExpr(internal::scalar_isfinite_op<Scalar, true>()).template cast<bool>();
}
// Coefficient-wise ternary operators.

View File

@ -132,8 +132,61 @@ static void test_isnan() {
}
}
static void test_isinf() {
Tensor<Scalar, 3> mat(2, 3, 7);
mat.setRandom();
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
if (internal::random<bool>()) {
mat(i, j, k) = std::numeric_limits<Scalar>::infinity();
}
}
}
}
Tensor<bool, 3> inf(2, 3, 7);
inf = (mat.isinf)();
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
VERIFY_IS_EQUAL(inf(i, j, k), (std::isinf)(mat(i, j, k)));
}
}
}
}
static void test_isfinite() {
Tensor<Scalar, 3> mat(2, 3, 7);
mat.setRandom();
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
if (internal::random<bool>()) {
mat(i, j, k) = std::numeric_limits<Scalar>::infinity();
}
if (internal::random<bool>()) {
mat(i, j, k) = std::numeric_limits<Scalar>::quiet_NaN();
}
}
}
}
Tensor<bool, 3> inf(2, 3, 7);
inf = (mat.isfinite)();
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
VERIFY_IS_EQUAL(inf(i, j, k), (std::isfinite)(mat(i, j, k)));
}
}
}
}
EIGEN_DECLARE_TEST(cxx11_tensor_comparisons) {
CALL_SUBTEST(test_orderings());
CALL_SUBTEST(test_equality());
CALL_SUBTEST(test_isnan());
CALL_SUBTEST(test_isinf());
CALL_SUBTEST(test_isfinite());
}