From 988f24b730fe812e2e31d332d33277752fba435d Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Thu, 20 Jun 2019 11:47:49 -0700 Subject: [PATCH] Various fixes for packet ops. 1. Fix buggy pcmp_eq and unit test for half types. 2. Add unit test for pselect and add specializations for SSE 4.1, AVX512, and half types. 3. Get rid of FIXME: Implement faster pnegate for half by XOR'ing with a sign bit mask. --- Eigen/src/Core/arch/AVX512/PacketMath.h | 18 ++++++++++ Eigen/src/Core/arch/GPU/PacketMathHalf.h | 46 ++++++++++++++++++------ Eigen/src/Core/arch/SSE/PacketMath.h | 6 ++++ test/packetmath.cpp | 32 +++++++++++++++-- 4 files changed, 89 insertions(+), 13 deletions(-) diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 64619ecd9..383c49636 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -252,6 +252,24 @@ EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b, } #endif +template <> +EIGEN_DEVICE_FUNC inline Packet16f pselect(const Packet16f& mask, + const Packet16f& a, + const Packet16f& b) { + __mmask16 mask16 = _mm512_cmp_epi32_mask( + _mm512_castps_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mask16, a, b); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask, + const Packet8d& a, + const Packet8d& b) { + __mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask), + _mm512_setzero_epi32(), _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mask8, a, b); +} + template <> EIGEN_STRONG_INLINE Packet16f pmin(const Packet16f& a, const Packet16f& b) { diff --git a/Eigen/src/Core/arch/GPU/PacketMathHalf.h b/Eigen/src/Core/arch/GPU/PacketMathHalf.h index b04a4d7d6..3273c5ea2 100644 --- a/Eigen/src/Core/arch/GPU/PacketMathHalf.h +++ b/Eigen/src/Core/arch/GPU/PacketMathHalf.h @@ -176,6 +176,15 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset(const Eigen: #endif } +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect(const half2& mask, + const half2& a, + const half2& b) { + half result_low = __low2half(mask) == half(0) ? __low2half(b) : __low2half(a); + half result_high = __high2half(mask) == half(0) ? __high2half(b) : __high2half(a); + return __halves2half2(result_low, result_high); +} + template <> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq(const half2& a, const half2& b) { @@ -726,18 +735,29 @@ template<> EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a,const Packet Packet16h r; r.x = pandnot(Packet8i(a.x),Packet8i(b.x)); return r; } +template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) { + Packet16h r; r.x = _mm256_blendv_epi8(b.x, a.x, mask.x); return r; +} + template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) { Packet16f af = half2float(a); Packet16f bf = half2float(b); Packet16f rf = pcmp_eq(af, bf); - return float2half(rf); + // Pack the 32-bit flags into 16-bits flags. + __m256i lo = _mm256_castps_si256(extract256<0>(rf)); + __m256i hi = _mm256_castps_si256(extract256<1>(rf)); + __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0), + _mm256_extractf128_si256(lo, 1)); + __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0), + _mm256_extractf128_si256(hi, 1)); + Packet16h result; result.x = _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); + return result; } template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) { - // FIXME we could do that with bit manipulation - Packet16f af = half2float(a); - Packet16f rf = pnegate(af); - return float2half(rf); + Packet16h sign_mask; sign_mask.x = _mm256_set1_epi16(static_cast(0x8000)); + Packet16h result; result.x = _mm256_xor_si256(a.x, sign_mask.x); + return result; } template<> EIGEN_STRONG_INLINE Packet16h padd(const Packet16h& a, const Packet16h& b) { @@ -1182,20 +1202,26 @@ template<> EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a,const Packet8h Packet8h r; r.x = _mm_andnot_si128(b.x,a.x); return r; } +template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) { + Packet8h r; r.x = _mm_blendv_epi8(b.x, a.x, mask.x); return r; +} + template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) { Packet8f af = half2float(a); Packet8f bf = half2float(b); Packet8f rf = pcmp_eq(af, bf); - return float2half(rf); + // Pack the 32-bit flags into 16-bits flags. + Packet8h result; result.x = _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0), + _mm256_extractf128_si256(_mm256_castps_si256(rf), 1)); + return result; } template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; } template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) { - // FIXME we could do that with bit manipulation - Packet8f af = half2float(a); - Packet8f rf = pnegate(af); - return float2half(rf); + Packet8h sign_mask; sign_mask.x = _mm_set1_epi16(static_cast(0x8000)); + Packet8h result; result.x = _mm_xor_si128(a.x, sign_mask.x); + return result; } template<> EIGEN_STRONG_INLINE Packet8h padd(const Packet8h& a, const Packet8h& b) { diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index b466d6462..0d571ce61 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -273,6 +273,12 @@ template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_pd(a,b,c); } #endif +#ifdef EIGEN_VECTORIZE_SSE4_1 +template<> EIGEN_DEVICE_FUNC inline Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) { return _mm_blendv_ps(b,a,mask); } + +template<> EIGEN_DEVICE_FUNC inline Packet2d pselect(const Packet2d& mask, const Packet2d& a, const Packet2d& b) { return _mm_blendv_pd(b,a,mask); } +#endif + template<> EIGEN_STRONG_INLINE Packet4f pmin(const Packet4f& a, const Packet4f& b) { #if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 // There appears to be a bug in GCC, by which the optimizer may diff --git a/test/packetmath.cpp b/test/packetmath.cpp index d018aaeb0..f1448f335 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -166,6 +166,16 @@ struct packet_helper VERIFY(areApprox(ref, data2, PacketSize) && #POP); \ } +#define CHECK_CWISE3_IF(COND, REFOP, POP) if (COND) { \ + packet_helper h; \ + for (int i = 0; i < PacketSize; ++i) \ + ref[i] = \ + REFOP(data1[i], data1[i + PacketSize], data1[i + 2 * PacketSize]); \ + h.store(data2, POP(h.load(data1), h.load(data1 + PacketSize), \ + h.load(data1 + 2 * PacketSize))); \ + VERIFY(areApprox(ref, data2, PacketSize) && #POP); \ +} + #define REF_ADD(a,b) ((a)+(b)) #define REF_SUB(a,b) ((a)-(b)) #define REF_MUL(a,b) ((a)*(b)) @@ -447,19 +457,35 @@ template void packetmath() data1[i] = internal::random(); unsigned char v = internal::random() ? 0xff : 0; char* bytes = (char*)(data1+PacketSize+i); - for(int k=0; k() ? 0xff : 0; + char* bytes = (char*)(data1+i); + for(int k=0; k(); + // "else" packet + data1[i+2*PacketSize] = internal::random(); + } + CHECK_CWISE3_IF(true, internal::pselect, internal::pselect); + } { for (int i = 0; i < PacketSize; ++i) { - data1[i] = internal::random(); - data2[i] = (i % 2) ? data1[i] : Scalar(0); + data1[i] = Scalar(i); + data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); }