diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 8b5f030bf..8322b38f8 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -443,6 +443,12 @@ pnot(const Packet& a) { template EIGEN_DEVICE_FUNC inline Packet pandnot(const Packet& a, const Packet& b) { return pand(a, pnot(b)); } +/** \internal \returns isnan(a) */ +template EIGEN_DEVICE_FUNC inline Packet +pisnan(const Packet& a) { + return pandnot(ptrue(a), pcmp_eq(a, a)); +} + // In the general case, use bitwise select. template struct pselect_impl { diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 3eb439434..c5e1cc08b 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -634,6 +634,7 @@ template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8 template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); } template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); } template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); } +template<> EIGEN_STRONG_INLINE Packet8f pisnan(const Packet8f& a) { return _mm256_cmp_ps(a,a,_CMP_UNORD_Q); } template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); } template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); } diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 98b55ea06..543f424b4 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -353,7 +353,7 @@ EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) { } template <> EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) { - return _mm512_sub_epi32(_mm512_set1_epi32(0), a); + return _mm512_sub_epi32(_mm512_setzero_si512(), a); } template <> @@ -580,66 +580,72 @@ EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) { return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); } +template <> +EIGEN_STRONG_INLINE Packet16f pisnan(const Packet16f& a) { + __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_UNORD_Q); + return _mm512_castsi512_ps(_mm512_maskz_set1_epi32(mask, 0xffffffffu)); +} + template <> EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) { __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ); return _mm512_castsi512_ps( - _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu)); } template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) { __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ); return _mm512_castsi512_ps( - _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu)); } template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) { __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ); return _mm512_castsi512_ps( - _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu)); } template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) { __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ); return _mm512_castsi512_ps( - _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu)); } template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) { __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ); - return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); + return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu); } template<> EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) { __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE); - return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); + return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu); } template<> EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) { __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT); - return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); + return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu); } template <> EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) { __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ); return _mm512_castsi512_pd( - _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); + _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu)); } template <> EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) { __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ); return _mm512_castsi512_pd( - _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); + _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu)); } template <> EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) { __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ); return _mm512_castsi512_pd( - _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); + _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu)); } template <> EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) { __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ); return _mm512_castsi512_pd( - _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); + _mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu)); } template<> EIGEN_STRONG_INLINE Packet16f print(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); } diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h index 62a74297f..755cef58c 100644 --- a/Eigen/src/Core/arch/AVX512/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -16,6 +16,33 @@ namespace Eigen { namespace internal { +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet16b pcast(const Packet16f& a) { + __mmask16 mask = _mm512_cmpneq_ps_mask(a, pzero(a)); + return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1)); +} + +template<> EIGEN_STRONG_INLINE Packet16f pcast(const Packet16b& a) { + return _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(a)); +} + template<> EIGEN_STRONG_INLINE Packet16i pcast(const Packet16f& a) { return _mm512_cvttps_epi32(a); } diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 21de0ac22..be215328a 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1124,7 +1124,7 @@ Packet psqrt_complex(const Packet& a) { Packet imag_inf_result; imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); // unless otherwise specified, if either the real or imaginary component is nan, the entire result is nan - Packet result_is_nan = pandnot(ptrue(result), pcmp_eq(result, result)); + Packet result_is_nan = pisnan(result); result = por(result_is_nan, result); return pselect(is_imag_inf, imag_inf_result, pselect(is_real_inf, real_inf_result, result)); @@ -1796,7 +1796,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one); const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg); const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg); - const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x)); + const Packet x_is_nan = pisnan(x); // Predicates for sign and magnitude of y. const Packet abs_y = pabs(y); @@ -1804,7 +1804,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac const Packet abs_y_is_zero = pcmp_eq(abs_y, cst_zero); const Packet y_is_neg = pcmp_lt(y, cst_zero); const Packet y_is_pos = pandnot(ptrue(y), por(abs_y_is_zero, y_is_neg)); - const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y)); + const Packet y_is_nan = pisnan(y); const Packet abs_y_is_inf = pcmp_eq(abs_y, cst_pos_inf); EIGEN_CONSTEXPR Scalar huge_exponent = (NumTraits::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits::epsilon(); diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index e909f13f8..8354c0a76 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -859,22 +859,39 @@ struct functor_traits > * \brief Template functor to compute whether a scalar is NaN * \sa class CwiseUnaryOp, ArrayBase::isnan() */ -template struct scalar_isnan_op { - typedef bool result_type; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const { +template +struct scalar_isnan_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const Scalar& a) const { #if defined(SYCL_DEVICE_ONLY) return numext::isnan(a); #else - return (numext::isnan)(a); + return numext::isnan EIGEN_NOT_A_MACRO (a); #endif } }; + + template -struct functor_traits > +struct scalar_isnan_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator() (const Scalar& a) const { +#if defined(SYCL_DEVICE_ONLY) + return (numext::isnan(a) ? ptrue(a) : pzero(a)); +#else + return (numext::isnan EIGEN_NOT_A_MACRO (a) ? ptrue(a) : pzero(a)); +#endif + } + template + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { + return pisnan(a); + } +}; + +template +struct functor_traits > { enum { Cost = NumTraits::MulCost, - PacketAccess = false + PacketAccess = packet_traits::HasCmp && UseTypedPredicate }; }; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 5e6cdb1ca..9ca9ca183 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -611,12 +611,13 @@ class TensorBase return operator!=(constant(threshold)); } - // Checks + // Predicates. EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + EIGEN_STRONG_INLINE const TensorConversionOp, const Derived>> (isnan)() const { - return unaryExpr(internal::scalar_isnan_op()); + return unaryExpr(internal::scalar_isnan_op()).template cast(); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> (isinf)() const { @@ -1219,4 +1220,3 @@ class TensorBase : public TensorBase { } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H - diff --git a/unsupported/test/cxx11_tensor_comparisons.cpp b/unsupported/test/cxx11_tensor_comparisons.cpp index d90ef7a9a..86c73355b 100644 --- a/unsupported/test/cxx11_tensor_comparisons.cpp +++ b/unsupported/test/cxx11_tensor_comparisons.cpp @@ -79,8 +79,36 @@ static void test_equality() } + +static void test_isnan() +{ + Tensor mat(2,3,7); + + mat.setRandom(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + if (internal::random()) { + mat(i,j,k) = std::numeric_limits::quiet_NaN(); + } + } + } + } + Tensor nan(2,3,7); + nan = (mat.isnan)(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(nan(i,j,k), (std::isnan)(mat(i,j,k))); + } + } + } + +} + EIGEN_DECLARE_TEST(cxx11_tensor_comparisons) { CALL_SUBTEST(test_orderings()); CALL_SUBTEST(test_equality()); + CALL_SUBTEST(test_isnan()); }