Vectorize tensor.isnan() by using typed predicates.

This commit is contained in:
Rasmus Munk Larsen 2023-03-16 04:04:22 +00:00
parent f02856c640
commit 0488b708b4
8 changed files with 110 additions and 25 deletions

View File

@ -443,6 +443,12 @@ pnot(const Packet& a) {
/** \internal \returns pand(a, pnot(b)), i.e. `a` combined with the complement of `b`. */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pandnot(const Packet& a, const Packet& b) { return pand(a, pnot(b)); } pandnot(const Packet& a, const Packet& b) { return pand(a, pnot(b)); }
/** \internal \returns isnan(a) as a packet-typed mask.
  *
  * Generic fallback: a value is NaN when it does not compare equal to itself,
  * so the result is ptrue(a) with the lanes of pcmp_eq(a, a) removed.
  */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pisnan(const Packet& a) {
  const Packet all_lanes = ptrue(a);
  const Packet equals_self = pcmp_eq(a, a);
  return pandnot(all_lanes, equals_self);
}
// In the general case, use bitwise select. // In the general case, use bitwise select.
template<typename Packet, typename EnableIf = void> template<typename Packet, typename EnableIf = void>
struct pselect_impl { struct pselect_impl {

View File

@ -634,6 +634,7 @@ template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8
// Packet8f a < b: forwards to _mm256_cmp_ps with the _CMP_LT_OQ predicate.
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); } template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); }
// Packet8f "a < b or either is NaN": uses the _CMP_NGE_UQ (not-greater-or-equal, unordered) predicate of _mm256_cmp_ps.
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); } template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); }
// Packet8f a == b: forwards to _mm256_cmp_ps with the _CMP_EQ_OQ predicate.
template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); } template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
// Packet8f NaN test: comparing a lane with itself using the unordered
// predicate (_CMP_UNORD_Q) selects the lanes that hold NaN.
template<>
EIGEN_STRONG_INLINE Packet8f pisnan(const Packet8f& a) {
  return _mm256_cmp_ps(a, a, _CMP_UNORD_Q);
}
// Packet4d a <= b: forwards to _mm256_cmp_pd with the _CMP_LE_OQ predicate.
template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); } template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); }
// Packet4d a < b: forwards to _mm256_cmp_pd with the _CMP_LT_OQ predicate.
template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); } template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); }

View File

@ -353,7 +353,7 @@ EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) {
} }
// Packet16i negation, computed as (0 - a) with _mm512_sub_epi32.
// The zero operand is spelled _mm512_set1_epi32(0) in the first copy of the
// return statement and _mm512_setzero_si512() in the second.
template <> template <>
EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) { EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) {
return _mm512_sub_epi32(_mm512_set1_epi32(0), a); return _mm512_sub_epi32(_mm512_setzero_si512(), a);
} }
template <> template <>
@ -580,66 +580,72 @@ EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
} }
// Packet16f NaN test.
// The self-comparison with _CMP_UNORD_Q produces a 16-bit lane mask; the mask
// is then widened into a vector holding 0xffffffff in the selected lanes
// (maskz zeroes the others) and reinterpreted as floats.
template <>
EIGEN_STRONG_INLINE Packet16f pisnan(const Packet16f& a) {
  const __mmask16 nan_lanes = _mm512_cmp_ps_mask(a, a, _CMP_UNORD_Q);
  return _mm512_castsi512_ps(
      _mm512_maskz_set1_epi32(nan_lanes, 0xffffffffu));
}
// Packet16f a == b.
// _mm512_cmp_ps_mask (_CMP_EQ_OQ) yields a 16-bit lane mask, which
// _mm512_mask_set1_epi32 expands to 0xffffffff in selected lanes on top of a
// zero vector; the integer vector is then reinterpreted as floats.
// The zero vector is _mm512_set1_epi32(0) in the first copy of the line and
// _mm512_setzero_epi32() in the second.
template <> template <>
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) { EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ); __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
return _mm512_castsi512_ps( return _mm512_castsi512_ps(
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
} }
// Packet16f a <= b: lane mask from _mm512_cmp_ps_mask (_CMP_LE_OQ), expanded
// to 0xffffffff per selected lane over a zero vector and reinterpreted as floats.
// Zero vector: _mm512_set1_epi32(0) in the first copy, _mm512_setzero_epi32() in the second.
template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) { template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ); __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
return _mm512_castsi512_ps( return _mm512_castsi512_ps(
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
} }
// Packet16f a < b: lane mask from _mm512_cmp_ps_mask (_CMP_LT_OQ), expanded
// to 0xffffffff per selected lane over a zero vector and reinterpreted as floats.
// Zero vector: _mm512_set1_epi32(0) in the first copy, _mm512_setzero_epi32() in the second.
template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) { template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ); __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ);
return _mm512_castsi512_ps( return _mm512_castsi512_ps(
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
} }
// Packet16f "a < b or either is NaN": lane mask from _mm512_cmp_ps_mask with
// _CMP_NGE_UQ (not-greater-or-equal, unordered), expanded to 0xffffffff per
// selected lane over a zero vector and reinterpreted as floats.
// Zero vector: _mm512_set1_epi32(0) in the first copy, _mm512_setzero_epi32() in the second.
template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) { template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ); __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ);
return _mm512_castsi512_ps( return _mm512_castsi512_ps(
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
} }
// Packet16i a == b: lane mask from _mm512_cmp_epi32_mask (_MM_CMPINT_EQ),
// expanded to 0xffffffff per selected lane over a zero vector.
// Zero vector: _mm512_set1_epi32(0) in the first copy, _mm512_setzero_epi32() in the second.
template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) { template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) {
__mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ); __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ);
return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu);
} }
// Packet16i a <= b: lane mask from _mm512_cmp_epi32_mask (_MM_CMPINT_LE),
// expanded to 0xffffffff per selected lane over a zero vector.
// Zero vector: _mm512_set1_epi32(0) in the first copy, _mm512_setzero_epi32() in the second.
template<> EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) { template<> EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) {
__mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE); __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE);
return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu);
} }
// Packet16i a < b: lane mask from _mm512_cmp_epi32_mask (_MM_CMPINT_LT),
// expanded to 0xffffffff per selected lane over a zero vector.
// Zero vector: _mm512_set1_epi32(0) in the first copy, _mm512_setzero_epi32() in the second.
template<> EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) { template<> EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) {
__mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT); __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT);
return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu);
} }
// Packet8d a == b: 8-bit lane mask from _mm512_cmp_pd_mask (_CMP_EQ_OQ),
// expanded by _mm512_mask_set1_epi64 to 0xffffffffffffffff per selected lane
// over a zero vector and reinterpreted as doubles.
// Zero vector: _mm512_set1_epi64(0) in the first copy, _mm512_setzero_epi32() in the second.
template <> template <>
EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) { EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ); __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ);
return _mm512_castsi512_pd( return _mm512_castsi512_pd(
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
} }
// Packet8d a <= b: 8-bit lane mask from _mm512_cmp_pd_mask (_CMP_LE_OQ),
// expanded to 0xffffffffffffffff per selected lane over a zero vector and
// reinterpreted as doubles.
// Zero vector: _mm512_set1_epi64(0) in the first copy, _mm512_setzero_epi32() in the second.
template <> template <>
EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) { EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ); __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
return _mm512_castsi512_pd( return _mm512_castsi512_pd(
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
} }
// Packet8d a < b: 8-bit lane mask from _mm512_cmp_pd_mask (_CMP_LT_OQ),
// expanded to 0xffffffffffffffff per selected lane over a zero vector and
// reinterpreted as doubles.
// Zero vector: _mm512_set1_epi64(0) in the first copy, _mm512_setzero_epi32() in the second.
template <> template <>
EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) { EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ); __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
return _mm512_castsi512_pd( return _mm512_castsi512_pd(
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
} }
// Packet8d "a < b or either is NaN": 8-bit lane mask from _mm512_cmp_pd_mask
// with _CMP_NGE_UQ (not-greater-or-equal, unordered), expanded to
// 0xffffffffffffffff per selected lane over a zero vector and reinterpreted as doubles.
// Zero vector: _mm512_set1_epi64(0) in the first copy, _mm512_setzero_epi32() in the second.
template <> template <>
EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) { EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ); __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ);
return _mm512_castsi512_pd( return _mm512_castsi512_pd(
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
} }
// Packet16f print: forwards to _mm512_roundscale_ps with _MM_FROUND_CUR_DIRECTION.
template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); } template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); }

View File

@ -16,6 +16,33 @@ namespace Eigen {
namespace internal { namespace internal {
// Cast traits for float -> bool: the cast is vectorized, with a source and
// target coefficient ratio of 1 each.
template <>
struct type_casting_traits<float, bool> {
  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
};
// Cast traits for bool -> float: the cast is vectorized, with a source and
// target coefficient ratio of 1 each.
template <>
struct type_casting_traits<bool,float> {
  enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 };
};
// Packet16f -> Packet16b.
// A lane is selected when the float compares not-equal to pzero(a); selected
// lanes take the 32-bit value 1 narrowed to a byte, and the maskz form zeroes
// the remaining lanes.
template<>
EIGEN_STRONG_INLINE Packet16b pcast<Packet16f, Packet16b>(const Packet16f& a) {
  const __mmask16 nonzero_lanes = _mm512_cmpneq_ps_mask(a, pzero(a));
  return _mm512_maskz_cvtepi32_epi8(nonzero_lanes, _mm512_set1_epi32(1));
}
// Packet16b -> Packet16f: each byte is first widened to a 32-bit integer
// (_mm512_cvtepi8_epi32) and then converted to float (_mm512_cvtepi32_ps).
template<>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16b, Packet16f>(const Packet16b& a) {
  return _mm512_cvtepi32_ps(
      _mm512_cvtepi8_epi32(a));
}
// Packet16f -> Packet16i: forwards to _mm512_cvttps_epi32 (float to 32-bit integer conversion with truncation).
template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) { template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
return _mm512_cvttps_epi32(a); return _mm512_cvttps_epi32(a);
} }

View File

@ -1124,7 +1124,7 @@ Packet psqrt_complex(const Packet& a) {
Packet imag_inf_result; Packet imag_inf_result;
imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask));
// unless otherwise specified, if either the real or imaginary component is nan, the entire result is nan // unless otherwise specified, if either the real or imaginary component is nan, the entire result is nan
Packet result_is_nan = pandnot(ptrue(result), pcmp_eq(result, result)); Packet result_is_nan = pisnan(result);
result = por(result_is_nan, result); result = por(result_is_nan, result);
return pselect(is_imag_inf, imag_inf_result, pselect(is_real_inf, real_inf_result, result)); return pselect(is_imag_inf, imag_inf_result, pselect(is_real_inf, real_inf_result, result));
@ -1796,7 +1796,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac
const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one); const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg); const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg); const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x)); const Packet x_is_nan = pisnan(x);
// Predicates for sign and magnitude of y. // Predicates for sign and magnitude of y.
const Packet abs_y = pabs(y); const Packet abs_y = pabs(y);
@ -1804,7 +1804,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac
const Packet abs_y_is_zero = pcmp_eq(abs_y, cst_zero); const Packet abs_y_is_zero = pcmp_eq(abs_y, cst_zero);
const Packet y_is_neg = pcmp_lt(y, cst_zero); const Packet y_is_neg = pcmp_lt(y, cst_zero);
const Packet y_is_pos = pandnot(ptrue(y), por(abs_y_is_zero, y_is_neg)); const Packet y_is_pos = pandnot(ptrue(y), por(abs_y_is_zero, y_is_neg));
const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y)); const Packet y_is_nan = pisnan(y);
const Packet abs_y_is_inf = pcmp_eq(abs_y, cst_pos_inf); const Packet abs_y_is_inf = pcmp_eq(abs_y, cst_pos_inf);
EIGEN_CONSTEXPR Scalar huge_exponent = EIGEN_CONSTEXPR Scalar huge_exponent =
(NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits<Scalar>::epsilon(); (NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits<Scalar>::epsilon();

View File

@ -859,22 +859,39 @@ struct functor_traits<scalar_ceil_op<Scalar> >
* \brief Template functor to compute whether a scalar is NaN * \brief Template functor to compute whether a scalar is NaN
* \sa class CwiseUnaryOp, ArrayBase::isnan() * \sa class CwiseUnaryOp, ArrayBase::isnan()
*/ */
// Primary scalar_isnan_op functor: operator() returns whether `a` is NaN via numext::isnan.
// The two copies on each line differ: the first declares a single template
// parameter and a `result_type` typedef of bool; the second adds a
// `bool UseTypedPredicate=false` parameter and returns `bool` directly.
template<typename Scalar> struct scalar_isnan_op { template<typename Scalar, bool UseTypedPredicate=false>
typedef bool result_type; struct scalar_isnan_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const Scalar& a) const {
// SYCL device builds call numext::isnan without the macro-protection used in the #else branch.
#if defined(SYCL_DEVICE_ONLY) #if defined(SYCL_DEVICE_ONLY)
return numext::isnan(a); return numext::isnan(a);
#else #else
// The call is written so that a function-like macro named `isnan` is not
// expanded: parentheses in the first copy, EIGEN_NOT_A_MACRO in the second.
return (numext::isnan)(a); return numext::isnan EIGEN_NOT_A_MACRO (a);
#endif #endif
} }
}; };
// scalar_isnan_op<Scalar, true>: typed-predicate specialization whose result
// has the operand's own type instead of bool.
// The first two lines also carry, in their first copy, the header of the
// former single-parameter functor_traits specialization.
template<typename Scalar> template<typename Scalar>
struct functor_traits<scalar_isnan_op<Scalar> > struct scalar_isnan_op<Scalar, true> {
// Scalar path: yields ptrue(a) when `a` is NaN and pzero(a) otherwise.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator() (const Scalar& a) const {
#if defined(SYCL_DEVICE_ONLY)
return (numext::isnan(a) ? ptrue(a) : pzero(a));
#else
// EIGEN_NOT_A_MACRO sits between the name and the argument list, as in the primary template.
return (numext::isnan EIGEN_NOT_A_MACRO (a) ? ptrue(a) : pzero(a));
#endif
}
// Packet path: delegates to pisnan for the given packet type.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
return pisnan(a);
}
};
// functor_traits for scalar_isnan_op: Cost is NumTraits<Scalar>::MulCost.
// PacketAccess is `false` in the first copy of the line; in the second it is
// enabled only when packet_traits<Scalar>::HasCmp is set and the
// typed-predicate variant (UseTypedPredicate) is selected.
template<typename Scalar, bool UseTypedPredicate>
struct functor_traits<scalar_isnan_op<Scalar, UseTypedPredicate> >
{ {
enum { enum {
Cost = NumTraits<Scalar>::MulCost, Cost = NumTraits<Scalar>::MulCost,
PacketAccess = false PacketAccess = packet_traits<Scalar>::HasCmp && UseTypedPredicate
}; };
}; };

View File

@ -611,12 +611,13 @@ class TensorBase<Derived, ReadOnlyAccessors>
return operator!=(constant(threshold)); return operator!=(constant(threshold));
} }
// Checks // Predicates.
// Coefficient-wise NaN predicate.
// First copy: returns the unary expression built from scalar_isnan_op<Scalar>.
// Second copy: builds the expression with scalar_isnan_op<Scalar, true> and
// converts it with cast<bool>(), so the return type is a TensorConversionOp to bool.
// The method name is parenthesized, as is done for isinf below.
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isnan_op<Scalar>, const Derived> EIGEN_STRONG_INLINE const TensorConversionOp<bool, const TensorCwiseUnaryOp<internal::scalar_isnan_op<Scalar, true>, const Derived>>
(isnan)() const { (isnan)() const {
return unaryExpr(internal::scalar_isnan_op<Scalar>()); return unaryExpr(internal::scalar_isnan_op<Scalar, true>()).template cast<bool>();
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isinf_op<Scalar>, const Derived> EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isinf_op<Scalar>, const Derived>
(isinf)() const { (isinf)() const {
@ -1219,4 +1220,3 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> {
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H #endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H

View File

@ -79,8 +79,36 @@ static void test_equality()
} }
static void test_isnan()
{
Tensor<Scalar, 3> mat(2,3,7);
mat.setRandom();
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
if (internal::random<bool>()) {
mat(i,j,k) = std::numeric_limits<Scalar>::quiet_NaN();
}
}
}
}
Tensor<bool, 3> nan(2,3,7);
nan = (mat.isnan)();
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
VERIFY_IS_EQUAL(nan(i,j,k), (std::isnan)(mat(i,j,k)));
}
}
}
}
// Test entry point: runs the ordering and equality subtests; the test_isnan
// subtest appears only once because it exists in the second copy alone.
EIGEN_DECLARE_TEST(cxx11_tensor_comparisons) EIGEN_DECLARE_TEST(cxx11_tensor_comparisons)
{ {
CALL_SUBTEST(test_orderings()); CALL_SUBTEST(test_orderings());
CALL_SUBTEST(test_equality()); CALL_SUBTEST(test_equality());
CALL_SUBTEST(test_isnan());
} }