mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-23 06:43:13 +08:00
Vectorize tensor.isnan() by using typed predicates.
This commit is contained in:
parent
f02856c640
commit
0488b708b4
@ -443,6 +443,12 @@ pnot(const Packet& a) {
|
||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||
pandnot(const Packet& a, const Packet& b) { return pand(a, pnot(b)); }
|
||||
|
||||
/** \internal \returns isnan(a) */
|
||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||
pisnan(const Packet& a) {
|
||||
return pandnot(ptrue(a), pcmp_eq(a, a));
|
||||
}
|
||||
|
||||
// In the general case, use bitwise select.
|
||||
template<typename Packet, typename EnableIf = void>
|
||||
struct pselect_impl {
|
||||
|
@ -634,6 +634,7 @@ template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pisnan(const Packet8f& a) { return _mm256_cmp_ps(a,a,_CMP_UNORD_Q); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); }
|
||||
|
@ -353,7 +353,7 @@ EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) {
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) {
|
||||
return _mm512_sub_epi32(_mm512_set1_epi32(0), a);
|
||||
return _mm512_sub_epi32(_mm512_setzero_si512(), a);
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -580,66 +580,72 @@ EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
|
||||
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pisnan(const Packet16f& a) {
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_UNORD_Q);
|
||||
return _mm512_castsi512_ps(_mm512_maskz_set1_epi32(mask, 0xffffffffu));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
|
||||
return _mm512_castsi512_ps(
|
||||
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
|
||||
_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
|
||||
return _mm512_castsi512_ps(
|
||||
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
|
||||
_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) {
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ);
|
||||
return _mm512_castsi512_ps(
|
||||
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
|
||||
_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) {
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ);
|
||||
return _mm512_castsi512_ps(
|
||||
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
|
||||
_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) {
|
||||
__mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ);
|
||||
return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
|
||||
return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu);
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) {
|
||||
__mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE);
|
||||
return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
|
||||
return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu);
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) {
|
||||
__mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT);
|
||||
return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
|
||||
return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
|
||||
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ);
|
||||
return _mm512_castsi512_pd(
|
||||
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
|
||||
_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
|
||||
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
|
||||
return _mm512_castsi512_pd(
|
||||
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
|
||||
_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
|
||||
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
|
||||
return _mm512_castsi512_pd(
|
||||
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
|
||||
_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
|
||||
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ);
|
||||
return _mm512_castsi512_pd(
|
||||
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
|
||||
_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); }
|
||||
|
@ -16,6 +16,33 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<float, bool> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<bool,float> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16b pcast<Packet16f, Packet16b>(const Packet16f& a) {
|
||||
__mmask16 mask = _mm512_cmpneq_ps_mask(a, pzero(a));
|
||||
return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16b, Packet16f>(const Packet16b& a) {
|
||||
return _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(a));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
|
||||
return _mm512_cvttps_epi32(a);
|
||||
}
|
||||
|
@ -1124,7 +1124,7 @@ Packet psqrt_complex(const Packet& a) {
|
||||
Packet imag_inf_result;
|
||||
imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask));
|
||||
// unless otherwise specified, if either the real or imaginary component is nan, the entire result is nan
|
||||
Packet result_is_nan = pandnot(ptrue(result), pcmp_eq(result, result));
|
||||
Packet result_is_nan = pisnan(result);
|
||||
result = por(result_is_nan, result);
|
||||
|
||||
return pselect(is_imag_inf, imag_inf_result, pselect(is_real_inf, real_inf_result, result));
|
||||
@ -1796,7 +1796,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac
|
||||
const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
|
||||
const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
|
||||
const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
|
||||
const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
|
||||
const Packet x_is_nan = pisnan(x);
|
||||
|
||||
// Predicates for sign and magnitude of y.
|
||||
const Packet abs_y = pabs(y);
|
||||
@ -1804,7 +1804,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac
|
||||
const Packet abs_y_is_zero = pcmp_eq(abs_y, cst_zero);
|
||||
const Packet y_is_neg = pcmp_lt(y, cst_zero);
|
||||
const Packet y_is_pos = pandnot(ptrue(y), por(abs_y_is_zero, y_is_neg));
|
||||
const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
|
||||
const Packet y_is_nan = pisnan(y);
|
||||
const Packet abs_y_is_inf = pcmp_eq(abs_y, cst_pos_inf);
|
||||
EIGEN_CONSTEXPR Scalar huge_exponent =
|
||||
(NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits<Scalar>::epsilon();
|
||||
|
@ -859,22 +859,39 @@ struct functor_traits<scalar_ceil_op<Scalar> >
|
||||
* \brief Template functor to compute whether a scalar is NaN
|
||||
* \sa class CwiseUnaryOp, ArrayBase::isnan()
|
||||
*/
|
||||
template<typename Scalar> struct scalar_isnan_op {
|
||||
typedef bool result_type;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar& a) const {
|
||||
template<typename Scalar, bool UseTypedPredicate=false>
|
||||
struct scalar_isnan_op {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator() (const Scalar& a) const {
|
||||
#if defined(SYCL_DEVICE_ONLY)
|
||||
return numext::isnan(a);
|
||||
#else
|
||||
return (numext::isnan)(a);
|
||||
return numext::isnan EIGEN_NOT_A_MACRO (a);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<typename Scalar>
|
||||
struct functor_traits<scalar_isnan_op<Scalar> >
|
||||
struct scalar_isnan_op<Scalar, true> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator() (const Scalar& a) const {
|
||||
#if defined(SYCL_DEVICE_ONLY)
|
||||
return (numext::isnan(a) ? ptrue(a) : pzero(a));
|
||||
#else
|
||||
return (numext::isnan EIGEN_NOT_A_MACRO (a) ? ptrue(a) : pzero(a));
|
||||
#endif
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
|
||||
return pisnan(a);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar, bool UseTypedPredicate>
|
||||
struct functor_traits<scalar_isnan_op<Scalar, UseTypedPredicate> >
|
||||
{
|
||||
enum {
|
||||
Cost = NumTraits<Scalar>::MulCost,
|
||||
PacketAccess = false
|
||||
PacketAccess = packet_traits<Scalar>::HasCmp && UseTypedPredicate
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -611,12 +611,13 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
return operator!=(constant(threshold));
|
||||
}
|
||||
|
||||
// Checks
|
||||
// Predicates.
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isnan_op<Scalar>, const Derived>
|
||||
EIGEN_STRONG_INLINE const TensorConversionOp<bool, const TensorCwiseUnaryOp<internal::scalar_isnan_op<Scalar, true>, const Derived>>
|
||||
(isnan)() const {
|
||||
return unaryExpr(internal::scalar_isnan_op<Scalar>());
|
||||
return unaryExpr(internal::scalar_isnan_op<Scalar, true>()).template cast<bool>();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_isinf_op<Scalar>, const Derived>
|
||||
(isinf)() const {
|
||||
@ -1219,4 +1220,3 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> {
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_CXX11_TENSOR_TENSOR_BASE_H
|
||||
|
||||
|
@ -79,8 +79,36 @@ static void test_equality()
|
||||
}
|
||||
|
||||
|
||||
|
||||
static void test_isnan()
|
||||
{
|
||||
Tensor<Scalar, 3> mat(2,3,7);
|
||||
|
||||
mat.setRandom();
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
if (internal::random<bool>()) {
|
||||
mat(i,j,k) = std::numeric_limits<Scalar>::quiet_NaN();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Tensor<bool, 3> nan(2,3,7);
|
||||
nan = (mat.isnan)();
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_EQUAL(nan(i,j,k), (std::isnan)(mat(i,j,k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(cxx11_tensor_comparisons)
|
||||
{
|
||||
CALL_SUBTEST(test_orderings());
|
||||
CALL_SUBTEST(test_equality());
|
||||
CALL_SUBTEST(test_isnan());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user