From dc5b1f7d75e9a8d70d85ada780970ea22f2a6b64 Mon Sep 17 00:00:00 2001 From: Jakub Lichman Date: Wed, 25 Aug 2021 19:38:23 +0000 Subject: [PATCH] AVX512 and AVX2 support for Packet16i and Packet8i added --- Eigen/src/Core/arch/AVX/PacketMath.h | 152 +++++++++++- Eigen/src/Core/arch/AVX512/PacketMath.h | 301 +++++++++++++++++++++++- Eigen/src/Core/arch/SSE/PacketMath.h | 5 +- test/vectorization_logic.cpp | 3 +- 4 files changed, 439 insertions(+), 22 deletions(-) diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 7fc32fd71..dc1a1d6b0 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -196,23 +196,21 @@ struct packet_traits : default_packet_traits { HasNdtri = 1 }; }; -#endif -template<> struct scalar_div_cost { enum { value = 14 }; }; -template<> struct scalar_div_cost { enum { value = 16 }; }; - -/* Proper support for integers is only provided by AVX2. In the meantime, we'll - use SSE instructions and packets to deal with integers. 
-template<> struct packet_traits<int>  : default_packet_traits
+template<> struct packet_traits<int> : default_packet_traits
 {
   typedef Packet8i type;
+  typedef Packet4i half;
   enum {
     Vectorizable = 1,
     AlignedOnScalar = 1,
     size=8
   };
 };
-*/
+#endif
+
+template<> struct scalar_div_cost<float,true> { enum { value = 14 }; };
+template<> struct scalar_div_cost<double,true> { enum { value = 16 }; };
 
 template<> struct unpacket_traits<Packet8f> {
   typedef float type;
@@ -226,8 +224,16 @@ template<> struct unpacket_traits<Packet4d> {
   typedef Packet2d half;
   enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
 };
-template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; };
-template<> struct unpacket_traits<Packet8bf> { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; };
+template<> struct unpacket_traits<Packet8i> {
+  typedef int type;
+  typedef Packet4i half;
+  enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
+template<> struct unpacket_traits<Packet8bf> {
+  typedef bfloat16 type;
+  typedef Packet8bf half;
+  enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false};
+};
 
 // Helper function for bit packing snippet of low precision comparison.
 // It packs the flags from 16x16 to 8x16.
@@ -258,6 +264,7 @@ template<> EIGEN_STRONG_INLINE Packet4d pload1(const double* from) { r template<> EIGEN_STRONG_INLINE Packet8f plset(const float& a) { return _mm256_add_ps(_mm256_set1_ps(a), _mm256_set_ps(7.0,6.0,5.0,4.0,3.0,2.0,1.0,0.0)); } template<> EIGEN_STRONG_INLINE Packet4d plset(const double& a) { return _mm256_add_pd(_mm256_set1_pd(a), _mm256_set_pd(3.0,2.0,1.0,0.0)); } +template<> EIGEN_STRONG_INLINE Packet8i plset(const int& a) { return _mm256_add_epi32(_mm256_set1_epi32(a), _mm256_set_epi32(7,6,5,4,3,2,1,0)); } template<> EIGEN_STRONG_INLINE Packet8f padd(const Packet8f& a, const Packet8f& b) { return _mm256_add_ps(a,b); } template<> EIGEN_STRONG_INLINE Packet4d padd(const Packet4d& a, const Packet4d& b) { return _mm256_add_pd(a,b); } @@ -291,6 +298,10 @@ template<> EIGEN_STRONG_INLINE Packet4d pnegate(const Packet4d& a) { return _mm256_sub_pd(_mm256_set1_pd(0.0),a); } +template<> EIGEN_STRONG_INLINE Packet8i pnegate(const Packet8i& a) +{ + return _mm256_sub_epi32(_mm256_set1_epi32(0), a); +} template<> EIGEN_STRONG_INLINE Packet8f pconj(const Packet8f& a) { return a; } template<> EIGEN_STRONG_INLINE Packet4d pconj(const Packet4d& a) { return a; } @@ -352,7 +363,26 @@ template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4 template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt_or_nan(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a, b, _CMP_NGE_UQ); } template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); } - +template<> EIGEN_STRONG_INLINE Packet8i pcmp_le(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_xor_si256(_mm256_cmpgt_epi32(a,b), _mm256_set1_epi32(-1)); +#else + __m128i lo = _mm_cmpgt_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); + lo = _mm_xor_si128(lo, _mm_set1_epi32(-1)); + __m128i hi = _mm_cmpgt_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); + hi = 
_mm_xor_si128(hi, _mm_set1_epi32(-1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} +template<> EIGEN_STRONG_INLINE Packet8i pcmp_lt(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_cmpgt_epi32(b,a); +#else + __m128i lo = _mm_cmpgt_epi32(_mm256_extractf128_si256(b, 0), _mm256_extractf128_si256(a, 0)); + __m128i hi = _mm_cmpgt_epi32(_mm256_extractf128_si256(b, 1), _mm256_extractf128_si256(a, 1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8i& b) { #ifdef EIGEN_VECTORIZE_AVX2 return _mm256_cmpeq_epi32(a,b); @@ -388,6 +418,9 @@ template<> EIGEN_STRONG_INLINE Packet4d pmin(const Packet4d& a, const return _mm256_min_pd(b,a); #endif } +template<> EIGEN_STRONG_INLINE Packet8i pmin(const Packet8i& a, const Packet8i& b) { + return _mm256_min_epi32(a, b); +} template<> EIGEN_STRONG_INLINE Packet8f pmax(const Packet8f& a, const Packet8f& b) { #if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 @@ -411,6 +444,9 @@ template<> EIGEN_STRONG_INLINE Packet4d pmax(const Packet4d& a, const return _mm256_max_pd(b,a); #endif } +template<> EIGEN_STRONG_INLINE Packet8i pmax(const Packet8i& a, const Packet8i& b) { + return _mm256_max_epi32(a, b); +} // Add specializations for min/max with prescribed NaN progation. 
template<> @@ -611,6 +647,12 @@ template<> EIGEN_STRONG_INLINE Packet4d ploaddup(const double* from) Packet4d tmp = _mm256_broadcast_pd((const __m128d*)(const void*)from); return _mm256_permute_pd(tmp, 3<<2); } +// Loads 4 integers from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, a3} +template<> EIGEN_STRONG_INLINE Packet8i ploaddup(const int* from) +{ + Packet8i a = _mm256_castsi128_si256(pload(from)); + return _mm256_permutevar8x32_epi32(a, _mm256_setr_epi32(0, 0, 1, 1, 2, 2, 3, 3)); +} // Loads 2 floats from memory a returns the packet {a0, a0 a0, a0, a1, a1, a1, a1} template<> EIGEN_STRONG_INLINE Packet8f ploadquad(const float* from) @@ -618,6 +660,10 @@ template<> EIGEN_STRONG_INLINE Packet8f ploadquad(const float* from) Packet8f tmp = _mm256_castps128_ps256(_mm_broadcast_ss(from)); return _mm256_insertf128_ps(tmp, _mm_broadcast_ss(from+1), 1); } +template<> EIGEN_STRONG_INLINE Packet8i ploadquad(const int* from) +{ + return _mm256_inserti128_si256(_mm256_set1_epi32(*from), _mm_set1_epi32(*(from+1)), 1); +} template<> EIGEN_STRONG_INLINE void pstore(float* to, const Packet8f& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_store_ps(to, from); } template<> EIGEN_STRONG_INLINE void pstore(double* to, const Packet4d& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_store_pd(to, from); } @@ -646,6 +692,11 @@ template<> EIGEN_DEVICE_FUNC inline Packet4d pgather(const dou { return _mm256_set_pd(from[3*stride], from[2*stride], from[1*stride], from[0*stride]); } +template<> EIGEN_DEVICE_FUNC inline Packet8i pgather(const int* from, Index stride) +{ + return _mm256_set_epi32(from[7*stride], from[6*stride], from[5*stride], from[4*stride], + from[3*stride], from[2*stride], from[1*stride], from[0*stride]); +} template<> EIGEN_DEVICE_FUNC inline void pscatter(float* to, const Packet8f& from, Index stride) { @@ -670,6 +721,20 @@ template<> EIGEN_DEVICE_FUNC inline void pscatter(double* to, to[stride*2] = _mm_cvtsd_f64(high); to[stride*3] = _mm_cvtsd_f64(_mm_shuffle_pd(high, 
high, 1)); } +template<> EIGEN_DEVICE_FUNC inline void pscatter(int* to, const Packet8i& from, Index stride) +{ + __m128i low = _mm256_extracti128_si256(from, 0); + to[stride*0] = _mm_extract_epi32(low, 0); + to[stride*1] = _mm_extract_epi32(low, 1); + to[stride*2] = _mm_extract_epi32(low, 2); + to[stride*3] = _mm_extract_epi32(low, 3); + + __m128i high = _mm256_extracti128_si256(from, 1); + to[stride*4] = _mm_extract_epi32(high, 0); + to[stride*5] = _mm_extract_epi32(high, 1); + to[stride*6] = _mm_extract_epi32(high, 2); + to[stride*7] = _mm_extract_epi32(high, 3); +} template<> EIGEN_STRONG_INLINE void pstore1(float* to, const float& a) { @@ -720,6 +785,10 @@ template<> EIGEN_STRONG_INLINE Packet4d preverse(const Packet4d& a) return _mm256_permute_pd(swap_halves,5); #endif } +template<> EIGEN_STRONG_INLINE Packet8i preverse(const Packet8i& a) +{ + return _mm256_castps_si256(preverse(_mm256_castsi256_ps(a))); +} // pabs should be ok template<> EIGEN_STRONG_INLINE Packet8f pabs(const Packet8f& a) @@ -732,6 +801,10 @@ template<> EIGEN_STRONG_INLINE Packet4d pabs(const Packet4d& a) const Packet4d mask = _mm256_castsi256_pd(_mm256_setr_epi32(0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF)); return _mm256_and_pd(a,mask); } +template<> EIGEN_STRONG_INLINE Packet8i pabs(const Packet8i& a) +{ + return _mm256_abs_epi32(a); +} template<> EIGEN_STRONG_INLINE Packet8f pfrexp(const Packet8f& a, Packet8f& exponent) { return pfrexp_generic(a,exponent); @@ -803,11 +876,19 @@ template<> EIGEN_STRONG_INLINE double predux(const Packet4d& a) { return predux(Packet2d(_mm_add_pd(_mm256_castpd256_pd128(a),_mm256_extractf128_pd(a,1)))); } +template<> EIGEN_STRONG_INLINE int predux(const Packet8i& a) +{ + return predux(Packet4i(_mm_add_epi32(_mm256_castsi256_si128(a),_mm256_extractf128_si256(a,1)))); +} template<> EIGEN_STRONG_INLINE Packet4f predux_half_dowto4(const Packet8f& a) { return 
_mm_add_ps(_mm256_castps256_ps128(a),_mm256_extractf128_ps(a,1)); } +template<> EIGEN_STRONG_INLINE Packet4i predux_half_dowto4(const Packet8i& a) +{ + return _mm_add_epi32(_mm256_castsi256_si128(a),_mm256_extractf128_si256(a,1)); +} template<> EIGEN_STRONG_INLINE float predux_mul(const Packet8f& a) { @@ -905,6 +986,55 @@ ptranspose(PacketBlock& kernel) { kernel.packet[3] = _mm256_permute2f128_ps(S2, S3, 0x31); } +#define MM256_SHUFFLE_EPI32(A, B, M) \ + _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(A), _mm256_castsi256_ps(B), M)) + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m256i T0 = _mm256_unpacklo_epi32(kernel.packet[0], kernel.packet[1]); + __m256i T1 = _mm256_unpackhi_epi32(kernel.packet[0], kernel.packet[1]); + __m256i T2 = _mm256_unpacklo_epi32(kernel.packet[2], kernel.packet[3]); + __m256i T3 = _mm256_unpackhi_epi32(kernel.packet[2], kernel.packet[3]); + __m256i T4 = _mm256_unpacklo_epi32(kernel.packet[4], kernel.packet[5]); + __m256i T5 = _mm256_unpackhi_epi32(kernel.packet[4], kernel.packet[5]); + __m256i T6 = _mm256_unpacklo_epi32(kernel.packet[6], kernel.packet[7]); + __m256i T7 = _mm256_unpackhi_epi32(kernel.packet[6], kernel.packet[7]); + __m256i S0 = MM256_SHUFFLE_EPI32(T0,T2,_MM_SHUFFLE(1,0,1,0)); + __m256i S1 = MM256_SHUFFLE_EPI32(T0,T2,_MM_SHUFFLE(3,2,3,2)); + __m256i S2 = MM256_SHUFFLE_EPI32(T1,T3,_MM_SHUFFLE(1,0,1,0)); + __m256i S3 = MM256_SHUFFLE_EPI32(T1,T3,_MM_SHUFFLE(3,2,3,2)); + __m256i S4 = MM256_SHUFFLE_EPI32(T4,T6,_MM_SHUFFLE(1,0,1,0)); + __m256i S5 = MM256_SHUFFLE_EPI32(T4,T6,_MM_SHUFFLE(3,2,3,2)); + __m256i S6 = MM256_SHUFFLE_EPI32(T5,T7,_MM_SHUFFLE(1,0,1,0)); + __m256i S7 = MM256_SHUFFLE_EPI32(T5,T7,_MM_SHUFFLE(3,2,3,2)); + kernel.packet[0] = _mm256_permute2f128_si256(S0, S4, 0x20); + kernel.packet[1] = _mm256_permute2f128_si256(S1, S5, 0x20); + kernel.packet[2] = _mm256_permute2f128_si256(S2, S6, 0x20); + kernel.packet[3] = _mm256_permute2f128_si256(S3, S7, 0x20); + kernel.packet[4] = 
_mm256_permute2f128_si256(S0, S4, 0x31); + kernel.packet[5] = _mm256_permute2f128_si256(S1, S5, 0x31); + kernel.packet[6] = _mm256_permute2f128_si256(S2, S6, 0x31); + kernel.packet[7] = _mm256_permute2f128_si256(S3, S7, 0x31); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m256i T0 = _mm256_unpacklo_epi32(kernel.packet[0], kernel.packet[1]); + __m256i T1 = _mm256_unpackhi_epi32(kernel.packet[0], kernel.packet[1]); + __m256i T2 = _mm256_unpacklo_epi32(kernel.packet[2], kernel.packet[3]); + __m256i T3 = _mm256_unpackhi_epi32(kernel.packet[2], kernel.packet[3]); + + __m256i S0 = MM256_SHUFFLE_EPI32(T0,T2,_MM_SHUFFLE(1,0,1,0)); + __m256i S1 = MM256_SHUFFLE_EPI32(T0,T2,_MM_SHUFFLE(3,2,3,2)); + __m256i S2 = MM256_SHUFFLE_EPI32(T1,T3,_MM_SHUFFLE(1,0,1,0)); + __m256i S3 = MM256_SHUFFLE_EPI32(T1,T3,_MM_SHUFFLE(3,2,3,2)); + + kernel.packet[0] = _mm256_permute2f128_si256(S0, S1, 0x20); + kernel.packet[1] = _mm256_permute2f128_si256(S2, S3, 0x20); + kernel.packet[2] = _mm256_permute2f128_si256(S0, S1, 0x31); + kernel.packet[3] = _mm256_permute2f128_si256(S2, S3, 0x31); +} + EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { __m256d T0 = _mm256_shuffle_pd(kernel.packet[0], kernel.packet[1], 15); diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 34d49ab66..6ce15c677 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -153,17 +153,16 @@ template<> struct packet_traits : default_packet_traits }; }; -/* TODO Implement AVX512 for integers -template<> struct packet_traits : default_packet_traits +template<> struct packet_traits : default_packet_traits { typedef Packet16i type; + typedef Packet8i half; enum { Vectorizable = 1, AlignedOnScalar = 1, - size=8 + size=16 }; }; -*/ template <> struct unpacket_traits { @@ -183,7 +182,7 @@ template <> struct unpacket_traits { typedef int type; typedef Packet8i half; - enum { size = 16, 
alignment=Aligned64, vectorizable=false, masked_load_available=false, masked_store_available=false }; + enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false }; }; template<> @@ -254,6 +253,12 @@ EIGEN_STRONG_INLINE Packet8d plset(const double& a) { return _mm512_add_pd(_mm512_set1_pd(a), _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0)); } +template <> +EIGEN_STRONG_INLINE Packet16i plset(const int& a) { + return _mm512_add_epi32( + _mm512_set1_epi32(a), + _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)); +} template <> EIGEN_STRONG_INLINE Packet16f padd(const Packet16f& a, @@ -295,6 +300,10 @@ template <> EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) { return _mm512_sub_pd(_mm512_set1_pd(0.0), a); } +template <> +EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) { + return _mm512_sub_epi32(_mm512_set1_epi32(0), a); +} template <> EIGEN_STRONG_INLINE Packet16f pconj(const Packet16f& a) { @@ -379,6 +388,11 @@ EIGEN_STRONG_INLINE Packet8d pmin(const Packet8d& a, // Arguments are reversed to match NaN propagation behavior of std::min. return _mm512_min_pd(b, a); } +template <> +EIGEN_STRONG_INLINE Packet16i pmin(const Packet16i& a, + const Packet16i& b) { + return _mm512_min_epi32(b, a); +} template <> EIGEN_STRONG_INLINE Packet16f pmax(const Packet16f& a, @@ -392,6 +406,11 @@ EIGEN_STRONG_INLINE Packet8d pmax(const Packet8d& a, // Arguments are reversed to match NaN propagation behavior of std::max. return _mm512_max_pd(b, a); } +template <> +EIGEN_STRONG_INLINE Packet16i pmax(const Packet16i& a, + const Packet16i& b) { + return _mm512_max_epi32(b, a); +} // Add specializations for min/max with prescribed NaN progation. 
template<>
@@ -493,10 +512,17 @@ template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) {
 }
 template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) {
-  __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _CMP_EQ_OQ);
+  __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ);
+  return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
+}
+template<> EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) {
+  __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE);
+  return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
+}
+template<> EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) {
+  __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT);
   return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
 }
-
 template <>
 EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
@@ -746,6 +772,16 @@ EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
 }
 #endif
+// Loads 8 integers from memory and returns the packet
+// {a0, a0, a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
+template <>
+EIGEN_STRONG_INLINE Packet16i ploaddup<Packet16i>(const int* from) {
+  __m256i low_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
+  __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
+  __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
+  return _mm512_castps_si512(pairs);
+}
+
 // Loads 4 floats from memory a returns the packet
 // {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
 template <>
@@ -766,6 +802,15 @@ EIGEN_STRONG_INLINE Packet8d ploadquad<Packet8d>(const double* from) {
   return _mm512_insertf64x4(tmp, lane1, 1);
 }
+// Loads 4 integers from memory and returns the packet
+// {a0, a0, a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
+template <>
+EIGEN_STRONG_INLINE Packet16i ploadquad<Packet16i>(const int* from) {
+  Packet16i tmp =
_mm512_castsi128_si512(ploadu(from)); + const Packet16i scatter_mask = _mm512_set_epi32(3,3,3,3, 2,2,2,2, 1,1,1,1, 0,0,0,0); + return _mm512_permutexvar_epi32(scatter_mask, tmp); +} + template <> EIGEN_STRONG_INLINE void pstore(float* to, const Packet16f& from) { EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ps(to, from); @@ -818,6 +863,15 @@ EIGEN_DEVICE_FUNC inline Packet8d pgather(const double* from, return _mm512_i32gather_pd(indices, from, 8); } +template <> +EIGEN_DEVICE_FUNC inline Packet16i pgather(const int* from, + Index stride) { + Packet16i stride_vector = _mm512_set1_epi32(convert_index(stride)); + Packet16i stride_multiplier = + _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier); + return _mm512_i32gather_epi32(indices, from, 4); +} template <> EIGEN_DEVICE_FUNC inline void pscatter(float* to, @@ -838,6 +892,16 @@ EIGEN_DEVICE_FUNC inline void pscatter(double* to, Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier); _mm512_i32scatter_pd(to, indices, from, 8); } +template <> +EIGEN_DEVICE_FUNC inline void pscatter(int* to, + const Packet16i& from, + Index stride) { + Packet16i stride_vector = _mm512_set1_epi32(convert_index(stride)); + Packet16i stride_multiplier = + _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier); + _mm512_i32scatter_epi32(to, indices, from, 4); +} template <> EIGEN_STRONG_INLINE void pstore1(float* to, const float& a) { @@ -882,6 +946,11 @@ template<> EIGEN_STRONG_INLINE Packet8d preverse(const Packet8d& a) return _mm512_permutexvar_pd(_mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7), a); } +template<> EIGEN_STRONG_INLINE Packet16i preverse(const Packet16i& a) +{ + return _mm512_permutexvar_epi32(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a); +} + template<> EIGEN_STRONG_INLINE 
Packet16f pabs(const Packet16f& a) {
   // _mm512_abs_ps intrinsic not found, so hack around it
@@ -893,6 +962,10 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
   return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a),
                                               _mm512_set1_epi64(0x7fffffffffffffff)));
 }
+template<> EIGEN_STRONG_INLINE Packet16i pabs(const Packet16i& a)
+{
+  return _mm512_abs_epi32(a);
+}
 template<>
 EIGEN_STRONG_INLINE Packet16f pfrexp(const Packet16f& a, Packet16f& exponent){
@@ -952,6 +1025,11 @@ template<> EIGEN_STRONG_INLINE Packet8d pldexp(const Packet8d& a, const Packet8d& exponent) {
 #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
   __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \
   __m256 OUTPUT##_1 = _mm512_extractf32x8_ps(INPUT, 1)
+
+// AVX512F does not define _mm512_extracti32x8_epi32 to extract _m256i from _m512i
+#define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \
+  __m256i OUTPUT##_0 = _mm512_extracti32x8_epi32(INPUT, 0); \
+  __m256i OUTPUT##_1 = _mm512_extracti32x8_epi32(INPUT, 1)
 #else
 #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
   __m256 OUTPUT##_0 = _mm256_insertf128_ps( \
       _mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 0)), \
       _mm512_extractf32x4_ps(INPUT, 1), 1);
@@ -960,11 +1038,22 @@ template<> EIGEN_STRONG_INLINE Packet8d pldexp(const Packet8d& a, const Packet8d& exponent) {
   __m256 OUTPUT##_1 = _mm256_insertf128_ps( \
       _mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 2)), \
       _mm512_extractf32x4_ps(INPUT, 3), 1);
+
+#define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \
+  __m256i OUTPUT##_0 = _mm256_insertf128_si256( \
+      _mm256_castsi128_si256(_mm512_extracti32x4_epi32(INPUT, 0)), \
+      _mm512_extracti32x4_epi32(INPUT, 1), 1); \
+  __m256i OUTPUT##_1 = _mm256_insertf128_si256( \
+      _mm256_castsi128_si256(_mm512_extracti32x4_epi32(INPUT, 2)), \
+      _mm512_extracti32x4_epi32(INPUT, 3), 1);
 #endif
 
 #ifdef EIGEN_VECTORIZE_AVX512DQ
 #define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
   OUTPUT = _mm512_insertf32x8(_mm512_castps256_ps512(INPUTA), INPUTB, 1);
+
+#define EIGEN_INSERT_8i_INTO_16i(OUTPUT, INPUTA, INPUTB) \
+  OUTPUT = _mm512_inserti32x8(_mm512_castsi256_si512(INPUTA), INPUTB, 1);
#else #define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \ OUTPUT = _mm512_undefined_ps(); \ @@ -972,6 +1061,13 @@ template<> EIGEN_STRONG_INLINE Packet8d pldexp(const Packet8d& a, cons OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \ OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \ OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3); + +#define EIGEN_INSERT_8i_INTO_16i(OUTPUT, INPUTA, INPUTB) \ + OUTPUT = _mm512_undefined_epi32(); \ + OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTA, 0), 0); \ + OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTA, 1), 1); \ + OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTB, 0), 2); \ + OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTB, 1), 3); #endif template <> @@ -1000,6 +1096,24 @@ EIGEN_STRONG_INLINE double predux(const Packet8d& a) { __m256d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1)); return _mm_cvtsd_f64(_mm256_castpd256_pd128(_mm256_hadd_pd(tmp0, tmp0))); } +template <> +EIGEN_STRONG_INLINE int predux(const Packet16i& a) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + __m256i lane0 = _mm512_extracti32x8_epi32(a, 0); + __m256i lane1 = _mm512_extracti32x8_epi32(a, 1); + Packet8i x = _mm256_add_epi32(lane0, lane1); + return predux(x); +#else + __m128i lane0 = _mm512_extracti32x4_epi32(a, 0); + __m128i lane1 = _mm512_extracti32x4_epi32(a, 1); + __m128i lane2 = _mm512_extracti32x4_epi32(a, 2); + __m128i lane3 = _mm512_extracti32x4_epi32(a, 3); + __m128i sum = _mm_add_epi32(_mm_add_epi32(lane0, lane1), _mm_add_epi32(lane2, lane3)); + sum = _mm_hadd_epi32(sum, sum); + sum = _mm_hadd_epi32(sum, _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(sum), 1))); + return _mm_cvtsi128_si32(sum); +#endif +} template <> EIGEN_STRONG_INLINE Packet8f predux_half_dowto4(const Packet16f& a) { @@ -1023,6 +1137,22 @@ EIGEN_STRONG_INLINE Packet4d predux_half_dowto4(const Packet8d& a) { 
__m256d lane1 = _mm512_extractf64x4_pd(a, 1); return _mm256_add_pd(lane0, lane1); } +template <> +EIGEN_STRONG_INLINE Packet8i predux_half_dowto4(const Packet16i& a) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + __m256i lane0 = _mm512_extracti32x8_epi32(a, 0); + __m256i lane1 = _mm512_extracti32x8_epi32(a, 1); + return _mm256_add_epi32(lane0, lane1); +#else + __m128i lane0 = _mm512_extracti32x4_epi32(a, 0); + __m128i lane1 = _mm512_extracti32x4_epi32(a, 1); + __m128i lane2 = _mm512_extracti32x4_epi32(a, 2); + __m128i lane3 = _mm512_extracti32x4_epi32(a, 3); + __m128i sum0 = _mm_add_epi32(lane0, lane2); + __m128i sum1 = _mm_add_epi32(lane1, lane3); + return _mm256_inserti128_si256(_mm256_castsi128_si256(sum0), sum1, 1); +#endif +} template <> EIGEN_STRONG_INLINE float predux_mul(const Packet16f& a) { @@ -1352,6 +1482,163 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 6, 8); PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 7, 8); } + +#define PACK_OUTPUT_I32(OUTPUT, INPUT, INDEX, STRIDE) \ + EIGEN_INSERT_8i_INTO_16i(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]); + +#define PACK_OUTPUT_I32_2(OUTPUT, INPUT, INDEX, STRIDE) \ + EIGEN_INSERT_8i_INTO_16i(OUTPUT[INDEX], INPUT[2 * INDEX], \ + INPUT[2 * INDEX + STRIDE]); + +#define SHUFFLE_EPI32(A, B, M) \ + _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(A), _mm512_castsi512_ps(B), M)) + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + __m512i T0 = _mm512_unpacklo_epi32(kernel.packet[0], kernel.packet[1]); + __m512i T1 = _mm512_unpackhi_epi32(kernel.packet[0], kernel.packet[1]); + __m512i T2 = _mm512_unpacklo_epi32(kernel.packet[2], kernel.packet[3]); + __m512i T3 = _mm512_unpackhi_epi32(kernel.packet[2], kernel.packet[3]); + __m512i T4 = _mm512_unpacklo_epi32(kernel.packet[4], kernel.packet[5]); + __m512i T5 = _mm512_unpackhi_epi32(kernel.packet[4], kernel.packet[5]); + __m512i T6 = _mm512_unpacklo_epi32(kernel.packet[6], kernel.packet[7]); 
+ __m512i T7 = _mm512_unpackhi_epi32(kernel.packet[6], kernel.packet[7]); + __m512i T8 = _mm512_unpacklo_epi32(kernel.packet[8], kernel.packet[9]); + __m512i T9 = _mm512_unpackhi_epi32(kernel.packet[8], kernel.packet[9]); + __m512i T10 = _mm512_unpacklo_epi32(kernel.packet[10], kernel.packet[11]); + __m512i T11 = _mm512_unpackhi_epi32(kernel.packet[10], kernel.packet[11]); + __m512i T12 = _mm512_unpacklo_epi32(kernel.packet[12], kernel.packet[13]); + __m512i T13 = _mm512_unpackhi_epi32(kernel.packet[12], kernel.packet[13]); + __m512i T14 = _mm512_unpacklo_epi32(kernel.packet[14], kernel.packet[15]); + __m512i T15 = _mm512_unpackhi_epi32(kernel.packet[14], kernel.packet[15]); + __m512i S0 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S1 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i S2 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S3 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i S4 = SHUFFLE_EPI32(T4, T6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S5 = SHUFFLE_EPI32(T4, T6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i S6 = SHUFFLE_EPI32(T5, T7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S7 = SHUFFLE_EPI32(T5, T7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i S8 = SHUFFLE_EPI32(T8, T10, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S9 = SHUFFLE_EPI32(T8, T10, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i S10 = SHUFFLE_EPI32(T9, T11, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S11 = SHUFFLE_EPI32(T9, T11, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i S12 = SHUFFLE_EPI32(T12, T14, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S13 = SHUFFLE_EPI32(T12, T14, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i S14 = SHUFFLE_EPI32(T13, T15, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S15 = SHUFFLE_EPI32(T13, T15, _MM_SHUFFLE(3, 2, 3, 2)); + + EIGEN_EXTRACT_8i_FROM_16i(S0, S0); + EIGEN_EXTRACT_8i_FROM_16i(S1, S1); + EIGEN_EXTRACT_8i_FROM_16i(S2, S2); + EIGEN_EXTRACT_8i_FROM_16i(S3, S3); + EIGEN_EXTRACT_8i_FROM_16i(S4, S4); + EIGEN_EXTRACT_8i_FROM_16i(S5, S5); + EIGEN_EXTRACT_8i_FROM_16i(S6, S6); + 
EIGEN_EXTRACT_8i_FROM_16i(S7, S7); + EIGEN_EXTRACT_8i_FROM_16i(S8, S8); + EIGEN_EXTRACT_8i_FROM_16i(S9, S9); + EIGEN_EXTRACT_8i_FROM_16i(S10, S10); + EIGEN_EXTRACT_8i_FROM_16i(S11, S11); + EIGEN_EXTRACT_8i_FROM_16i(S12, S12); + EIGEN_EXTRACT_8i_FROM_16i(S13, S13); + EIGEN_EXTRACT_8i_FROM_16i(S14, S14); + EIGEN_EXTRACT_8i_FROM_16i(S15, S15); + + PacketBlock tmp; + + tmp.packet[0] = _mm256_permute2f128_si256(S0_0, S4_0, 0x20); + tmp.packet[1] = _mm256_permute2f128_si256(S1_0, S5_0, 0x20); + tmp.packet[2] = _mm256_permute2f128_si256(S2_0, S6_0, 0x20); + tmp.packet[3] = _mm256_permute2f128_si256(S3_0, S7_0, 0x20); + tmp.packet[4] = _mm256_permute2f128_si256(S0_0, S4_0, 0x31); + tmp.packet[5] = _mm256_permute2f128_si256(S1_0, S5_0, 0x31); + tmp.packet[6] = _mm256_permute2f128_si256(S2_0, S6_0, 0x31); + tmp.packet[7] = _mm256_permute2f128_si256(S3_0, S7_0, 0x31); + + tmp.packet[8] = _mm256_permute2f128_si256(S0_1, S4_1, 0x20); + tmp.packet[9] = _mm256_permute2f128_si256(S1_1, S5_1, 0x20); + tmp.packet[10] = _mm256_permute2f128_si256(S2_1, S6_1, 0x20); + tmp.packet[11] = _mm256_permute2f128_si256(S3_1, S7_1, 0x20); + tmp.packet[12] = _mm256_permute2f128_si256(S0_1, S4_1, 0x31); + tmp.packet[13] = _mm256_permute2f128_si256(S1_1, S5_1, 0x31); + tmp.packet[14] = _mm256_permute2f128_si256(S2_1, S6_1, 0x31); + tmp.packet[15] = _mm256_permute2f128_si256(S3_1, S7_1, 0x31); + + // Second set of _m256 outputs + tmp.packet[16] = _mm256_permute2f128_si256(S8_0, S12_0, 0x20); + tmp.packet[17] = _mm256_permute2f128_si256(S9_0, S13_0, 0x20); + tmp.packet[18] = _mm256_permute2f128_si256(S10_0, S14_0, 0x20); + tmp.packet[19] = _mm256_permute2f128_si256(S11_0, S15_0, 0x20); + tmp.packet[20] = _mm256_permute2f128_si256(S8_0, S12_0, 0x31); + tmp.packet[21] = _mm256_permute2f128_si256(S9_0, S13_0, 0x31); + tmp.packet[22] = _mm256_permute2f128_si256(S10_0, S14_0, 0x31); + tmp.packet[23] = _mm256_permute2f128_si256(S11_0, S15_0, 0x31); + + tmp.packet[24] = _mm256_permute2f128_si256(S8_1, 
S12_1, 0x20); + tmp.packet[25] = _mm256_permute2f128_si256(S9_1, S13_1, 0x20); + tmp.packet[26] = _mm256_permute2f128_si256(S10_1, S14_1, 0x20); + tmp.packet[27] = _mm256_permute2f128_si256(S11_1, S15_1, 0x20); + tmp.packet[28] = _mm256_permute2f128_si256(S8_1, S12_1, 0x31); + tmp.packet[29] = _mm256_permute2f128_si256(S9_1, S13_1, 0x31); + tmp.packet[30] = _mm256_permute2f128_si256(S10_1, S14_1, 0x31); + tmp.packet[31] = _mm256_permute2f128_si256(S11_1, S15_1, 0x31); + + // Pack them into the output + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 0, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 1, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 2, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 3, 16); + + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 4, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 5, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 6, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 7, 16); + + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 8, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 9, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 10, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 11, 16); + + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 12, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 13, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 14, 16); + PACK_OUTPUT_I32(kernel.packet, tmp.packet, 15, 16); +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + __m512i T0 = _mm512_unpacklo_epi32(kernel.packet[0], kernel.packet[1]); + __m512i T1 = _mm512_unpackhi_epi32(kernel.packet[0], kernel.packet[1]); + __m512i T2 = _mm512_unpacklo_epi32(kernel.packet[2], kernel.packet[3]); + __m512i T3 = _mm512_unpackhi_epi32(kernel.packet[2], kernel.packet[3]); + + __m512i S0 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S1 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i S2 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i S3 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(3, 2, 3, 2)); + + 
EIGEN_EXTRACT_8i_FROM_16i(S0, S0); + EIGEN_EXTRACT_8i_FROM_16i(S1, S1); + EIGEN_EXTRACT_8i_FROM_16i(S2, S2); + EIGEN_EXTRACT_8i_FROM_16i(S3, S3); + + PacketBlock tmp; + + tmp.packet[0] = _mm256_permute2f128_si256(S0_0, S1_0, 0x20); + tmp.packet[1] = _mm256_permute2f128_si256(S2_0, S3_0, 0x20); + tmp.packet[2] = _mm256_permute2f128_si256(S0_0, S1_0, 0x31); + tmp.packet[3] = _mm256_permute2f128_si256(S2_0, S3_0, 0x31); + + tmp.packet[4] = _mm256_permute2f128_si256(S0_1, S1_1, 0x20); + tmp.packet[5] = _mm256_permute2f128_si256(S2_1, S3_1, 0x20); + tmp.packet[6] = _mm256_permute2f128_si256(S0_1, S1_1, 0x31); + tmp.packet[7] = _mm256_permute2f128_si256(S2_1, S3_1, 0x31); + + PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 0, 1); + PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 1, 1); + PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 2, 1); + PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 3, 1); +} + template <> EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& /*ifPacket*/, const Packet16f& /*thenPacket*/, diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index db102c73a..d658f65bb 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -180,7 +180,6 @@ struct packet_traits : default_packet_traits { HasRint = 1 }; }; -#endif template<> struct packet_traits : default_packet_traits { typedef Packet4i type; @@ -194,7 +193,7 @@ template<> struct packet_traits : default_packet_traits HasBlend = 1 }; }; - +#endif template<> struct packet_traits : default_packet_traits { typedef Packet16b type; @@ -233,7 +232,7 @@ template<> struct unpacket_traits { template<> struct unpacket_traits { typedef int type; typedef Packet4i half; - enum {size=4, alignment=Aligned16, vectorizable=false, masked_load_available=false, masked_store_available=false}; + enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; template<> struct unpacket_traits { typedef 
bool type; diff --git a/test/vectorization_logic.cpp b/test/vectorization_logic.cpp index 97c0bdad9..602e9f15c 100644 --- a/test/vectorization_logic.cpp +++ b/test/vectorization_logic.cpp @@ -337,7 +337,8 @@ struct vectorization_logic_half ((!EIGEN_UNALIGNED_VECTORIZE) && (sizeof(Scalar)==16)) ? NoUnrolling : CompleteUnrolling)); VERIFY(test_assign(Matrix3(),Matrix3().cwiseQuotient(Matrix3()), - PacketTraits::HasDiv ? LinearVectorizedTraversal : LinearTraversal,CompleteUnrolling)); + PacketTraits::HasDiv ? LinearVectorizedTraversal : LinearTraversal, + PacketTraits::HasDiv ? CompleteUnrolling : NoUnrolling)); VERIFY(test_assign(Matrix(),Matrix()+Matrix(), sizeof(Scalar)==16 ? InnerVectorizedTraversal : (EIGEN_UNALIGNED_VECTORIZE ? LinearVectorizedTraversal : LinearTraversal),