From 3c9109238fddbd65412b34ae7a503b08ad786970 Mon Sep 17 00:00:00 2001
From: Rasmus Munk Larsen
Date: Tue, 9 Apr 2024 22:58:44 +0000
Subject: [PATCH] Add support for Packet8l to AVX512.

---
 Eigen/src/Core/arch/AVX/PacketMath.h     |   4 +-
 Eigen/src/Core/arch/AVX512/PacketMath.h  | 411 +++++++++++++++++++----
 Eigen/src/Core/arch/AVX512/TypeCasting.h |  41 +++
 3 files changed, 396 insertions(+), 60 deletions(-)

diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index 2383e46e5..a53c38ded 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -1844,7 +1844,7 @@ EIGEN_STRONG_INLINE Packet8f psignbit(const Packet8f& a) {
 #endif
 }
 template <>
-EIGEN_STRONG_INLINE Packet8ui psignbit(const Packet8ui& a) {
+EIGEN_STRONG_INLINE Packet8ui psignbit(const Packet8ui& /*unused*/) {
   return _mm256_setzero_si256();
 }
 #ifdef EIGEN_VECTORIZE_AVX2
@@ -1853,7 +1853,7 @@ EIGEN_STRONG_INLINE Packet4d psignbit(const Packet4d& a) {
   return _mm256_castsi256_pd(_mm256_cmpgt_epi64(_mm256_setzero_si256(), _mm256_castpd_si256(a)));
 }
 template <>
-EIGEN_STRONG_INLINE Packet4ul psignbit(const Packet4ul& a) {
+EIGEN_STRONG_INLINE Packet4ul psignbit(const Packet4ul& /*unused*/) {
   return _mm256_setzero_si256();
 }
 #endif
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index ed2f189aa..c1bed8020 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -34,7 +34,7 @@ namespace internal {
 typedef __m512 Packet16f;
 typedef __m512i Packet16i;
 typedef __m512d Packet8d;
-// TODO(rmlarsen): Add support for Packet8l.
+typedef eigen_packet_wrapper<__m512i, 1> Packet8l;
 #ifndef EIGEN_VECTORIZE_AVX512FP16
 typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
 #endif
@@ -52,6 +52,10 @@ template <>
 struct is_arithmetic<__m512d> {
   enum { value = true };
 };
+template <>
+struct is_arithmetic<Packet8l> {
+  enum { value = true };
+};
 
 #ifndef EIGEN_VECTORIZE_AVX512FP16
 template <>
@@ -171,6 +175,13 @@ struct packet_traits<int> : default_packet_traits {
   enum { Vectorizable = 1, AlignedOnScalar = 1, HasBlend = 0, HasCmp = 1, HasDiv = 1, size = 16 };
 };
 
+template <>
+struct packet_traits<int64_t> : default_packet_traits {
+  typedef Packet8l type;
+  typedef Packet4l half;
+  enum { Vectorizable = 1, AlignedOnScalar = 1, HasCmp = 1, size = 8 };
+};
+
 template <>
 struct unpacket_traits<Packet16f> {
   typedef float type;
   typedef Packet8f half;
@@ -190,6 +201,7 @@ template <>
 struct unpacket_traits<Packet8d> {
   typedef double type;
   typedef Packet4d half;
+  typedef Packet8l integer_packet;
   typedef uint8_t mask_t;
   enum {
     size = 8,
@@ -213,6 +225,19 @@ struct unpacket_traits<Packet16i> {
   };
 };
 
+template <>
+struct unpacket_traits<Packet8l> {
+  typedef int64_t type;
+  typedef Packet4l half;
+  enum {
+    size = 8,
+    alignment = Aligned64,
+    vectorizable = true,
+    masked_load_available = false,
+    masked_store_available = false
+  };
+};
+
 #ifndef EIGEN_VECTORIZE_AVX512FP16
 template <>
 struct unpacket_traits<Packet16h> {
@@ -240,6 +265,10 @@ template <>
 EIGEN_STRONG_INLINE Packet16i pset1<Packet16i>(const int& from) {
   return _mm512_set1_epi32(from);
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l pset1<Packet8l>(const int64_t& from) {
+  return _mm512_set1_epi64(from);
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
@@ -264,6 +293,11 @@ EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) {
   return _mm512_setzero_si512();
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet8l pzero(const Packet8l& /*a*/) {
+  return _mm512_setzero_si512();
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) {
   return _mm512_castsi512_ps(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1));
@@ -276,6 +310,10 @@ template <>
 EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) {
   return _mm512_castsi512_pd(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1));
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l peven_mask(const Packet8l& /*a*/) {
+  return _mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1);
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
@@ -313,6 +351,10 @@ template <>
 EIGEN_STRONG_INLINE Packet16i plset<Packet16i>(const int& a) {
   return _mm512_add_epi32(_mm512_set1_epi32(a), _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l plset<Packet8l>(const int64_t& a) {
+  return _mm512_add_epi64(_mm512_set1_epi64(a), _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0));
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a, const Packet16f& b) {
@@ -326,6 +368,10 @@ template <>
 EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a, const Packet16i& b) {
   return _mm512_add_epi32(a, b);
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l padd<Packet8l>(const Packet8l& a, const Packet8l& b) {
+  return _mm512_add_epi64(a, b);
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a, const Packet16f& b, uint16_t umask) {
@@ -350,6 +396,10 @@ template <>
 EIGEN_STRONG_INLINE Packet16i psub<Packet16i>(const Packet16i& a, const Packet16i& b) {
   return _mm512_sub_epi32(a, b);
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l psub<Packet8l>(const Packet8l& a, const Packet8l& b) {
+  return _mm512_sub_epi64(a, b);
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pnegate(const Packet16f& a) {
@@ -372,6 +422,10 @@ template <>
 EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) {
   return _mm512_sub_epi32(_mm512_setzero_si512(), a);
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l pnegate(const Packet8l& a) {
+  return _mm512_sub_epi64(_mm512_setzero_si512(), a);
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pconj(const Packet16f& a) {
@@ -385,6 +439,10 @@ template <>
 EIGEN_STRONG_INLINE Packet16i pconj(const Packet16i& a) {
   return a;
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l pconj(const Packet8l& a) {
+  return a;
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pmul<Packet16f>(const Packet16f& a, const Packet16f& b) {
@@ -398,6 +456,14 @@ template <>
 EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a, const Packet16i& b) {
   return _mm512_mullo_epi32(a, b);
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l pmul<Packet8l>(const Packet8l& a, const Packet8l& b) {
+#ifdef EIGEN_VECTORIZE_AVX512DQ
+  return _mm512_mullo_epi64(a, b);
+#else
+  return _mm512_mullox_epi64(a, b);
+#endif
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pdiv<Packet16f>(const Packet16f& a, const Packet16f& b) {
@@ -466,6 +532,12 @@ EIGEN_DEVICE_FUNC inline Packet16i pselect(const Packet16i& mask, const Packet16
   return _mm512_mask_blend_epi32(mask16, a, b);
 }
 
+template <>
+EIGEN_DEVICE_FUNC inline Packet8l pselect(const Packet8l& mask, const Packet8l& a, const Packet8l& b) {
+  __mmask8 mask8 = _mm512_cmpeq_epi64_mask(mask, _mm512_setzero_si512());
+  return _mm512_mask_blend_epi64(mask8, a, b);
+}
+
 template <>
 EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask, const Packet8d& a, const Packet8d& b) {
   __mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
@@ -486,6 +558,10 @@ template <>
 EIGEN_STRONG_INLINE Packet16i pmin<Packet16i>(const Packet16i& a, const Packet16i& b) {
   return _mm512_min_epi32(b, a);
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l pmin<Packet8l>(const Packet8l& a, const Packet8l& b) {
+  return _mm512_min_epi64(b, a);
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pmax<Packet16f>(const Packet16f& a, const Packet16f& b) {
@@ -501,6 +577,10 @@ template <>
 EIGEN_STRONG_INLINE Packet16i pmax<Packet16i>(const Packet16i& a, const Packet16i& b) {
   return _mm512_max_epi32(b, a);
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l pmax<Packet8l>(const Packet8l& a, const Packet8l& b) {
+  return _mm512_max_epi64(b, a);
+}
 
 // Add specializations for min/max with prescribed NaN propagation.
 template <>
@@ -593,46 +673,62 @@ EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
 template <>
 EIGEN_STRONG_INLINE Packet16f pisnan(const Packet16f& a) {
   __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_UNORD_Q);
-  return _mm512_castsi512_ps(_mm512_maskz_set1_epi32(mask, 0xffffffffu));
+  return _mm512_castsi512_ps(_mm512_maskz_set1_epi32(mask, int32_t(-1)));
 }
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
   __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
-  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
+  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
 }
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
   __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
-  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
+  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
 }
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) {
   __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ);
-  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
+  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
 }
 
 template <>
 EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) {
   __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ);
-  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu));
+  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
 }
 
 template <>
 EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) {
   __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ);
-  return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu);
+  return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
 }
 template <>
 EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) {
   __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE);
-  return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu);
+  return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
 }
 template <>
 EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) {
   __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT);
-  return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, 0xffffffffu);
+  return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8l pcmp_eq(const Packet8l& a, const Packet8l& b) {
+  __mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_EQ);
+  return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8l pcmp_le(const Packet8l& a, const Packet8l& b) {
+  __mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_LE);
+  return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8l pcmp_lt(const Packet8l& a, const Packet8l& b) {
+  __mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_LT);
+  return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
 }
 
 template <>
@@ -685,7 +781,12 @@ EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) {
 template <>
 EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
-  return _mm512_set1_epi32(0xffffffffu);
+  return _mm512_set1_epi32(int32_t(-1));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8l ptrue<Packet8l>(const Packet8l& /*a*/) {
+  return _mm512_set1_epi64(int64_t(-1));
 }
 
 template <>
@@ -703,6 +804,11 @@ EIGEN_STRONG_INLINE Packet16i pand<Packet16i>(const Packet16i& a, const Packet16
   return _mm512_and_si512(a, b);
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet8l pand<Packet8l>(const Packet8l& a, const Packet8l& b) {
+  return _mm512_and_si512(a, b);
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a, const Packet16f& b) {
 #ifdef EIGEN_VECTORIZE_AVX512DQ
@@ -732,6 +838,11 @@ EIGEN_STRONG_INLINE Packet16i por<Packet16i>(const Packet16i& a, const Packet16i
   return _mm512_or_si512(a, b);
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet8l por<Packet8l>(const Packet8l& a, const Packet8l& b) {
+  return _mm512_or_si512(a, b);
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, const Packet16f& b) {
 #ifdef EIGEN_VECTORIZE_AVX512DQ
@@ -755,6 +866,11 @@ EIGEN_STRONG_INLINE Packet16i pxor<Packet16i>(const Packet16i& a, const Packet16
   return _mm512_xor_si512(a, b);
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet8l pxor<Packet8l>(const Packet8l& a, const Packet8l& b) {
+  return _mm512_xor_si512(a, b);
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, const Packet16f& b) {
 #ifdef EIGEN_VECTORIZE_AVX512DQ
@@ -778,6 +894,11 @@ EIGEN_STRONG_INLINE Packet16i pandnot<Packet16i>(const Packet16i& a, const Packe
   return _mm512_andnot_si512(b, a);
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet8l pandnot<Packet8l>(const Packet8l& a, const Packet8l& b) {
+  return _mm512_andnot_si512(b, a);
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, const Packet16f& b) {
 #ifdef EIGEN_VECTORIZE_AVX512DQ
@@ -825,6 +946,21 @@ EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) {
   return _mm512_slli_epi32(a, N);
 }
 
+template <int N>
+EIGEN_STRONG_INLINE Packet8l parithmetic_shift_right(Packet8l a) {
+  return _mm512_srai_epi64(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet8l plogical_shift_right(Packet8l a) {
+  return _mm512_srli_epi64(a, N);
+}
+
+template <int N>
+EIGEN_STRONG_INLINE Packet8l plogical_shift_left(Packet8l a) {
+  return _mm512_slli_epi64(a, N);
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f pload<Packet16f>(const float* from) {
   EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ps(from);
 }
@@ -835,7 +971,11 @@ EIGEN_STRONG_INLINE Packet8d pload<Packet8d>(const double* from) {
 }
 template <>
 EIGEN_STRONG_INLINE Packet16i pload<Packet16i>(const int* from) {
-  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512(reinterpret_cast<const __m512i*>(from));
+  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_epi32(from);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8l pload<Packet8l>(const int64_t* from) {
+  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_epi64(from);
 }
 
 template <>
@@ -848,7 +988,11 @@ EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from) {
 }
 template <>
 EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) {
-  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(from));
+  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_epi32(from);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8l ploadu<Packet8l>(const int64_t* from) {
+  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_epi64(from);
 }
 
 template <>
@@ -868,42 +1012,35 @@ template <>
 EIGEN_STRONG_INLINE Packet16f ploaddup<Packet16f>(const float* from) {
   // an unaligned load is required here as there is no requirement
   // on the alignment of input pointer 'from'
-  __m256i low_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
+  __m256i low_half = _mm256_castps_si256(_mm256_loadu_ps(from));
   __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
   __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
   return pairs;
 }
 
-#ifdef EIGEN_VECTORIZE_AVX512DQ
-// FIXME: this does not look optimal, better load a Packet4d and shuffle...
-// Loads 4 doubles from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3,
+// Loads 4 doubles from memory and returns the packet {a0, a0, a1, a1, a2, a2, a3,
 // a3}
 template <>
 EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
-  __m512d x = _mm512_setzero_pd();
-  x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[0]), 0);
-  x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[1]), 1);
-  x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[2]), 2);
-  x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[3]), 3);
-  return x;
+  Packet8d tmp = _mm512_castpd256_pd512(ploadu<Packet4d>(from));
+  const Packet8l scatter_mask = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0);
+  return _mm512_permutexvar_pd(scatter_mask, tmp);
 }
-#else
+
+// Loads 4 int64_t from memory and returns the packet {a0, a0, a1, a1, a2, a2, a3,
+// a3}
 template <>
-EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
-  __m512d x = _mm512_setzero_pd();
-  x = _mm512_mask_broadcastsd_pd(x, 0x3 << 0, _mm_load_sd(from + 0));
-  x = _mm512_mask_broadcastsd_pd(x, 0x3 << 2, _mm_load_sd(from + 1));
-  x = _mm512_mask_broadcastsd_pd(x, 0x3 << 4, _mm_load_sd(from + 2));
-  x = _mm512_mask_broadcastsd_pd(x, 0x3 << 6, _mm_load_sd(from + 3));
-  return x;
+EIGEN_STRONG_INLINE Packet8l ploaddup<Packet8l>(const int64_t* from) {
+  Packet8l tmp = _mm512_castsi256_si512(ploadu<Packet4l>(from));
+  const Packet8l scatter_mask = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0);
+  return _mm512_permutexvar_epi64(scatter_mask, tmp);
 }
-#endif
 
 // Loads 8 integers from memory and returns the packet
 // {a0, a0, a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
 template <>
 EIGEN_STRONG_INLINE Packet16i ploaddup<Packet16i>(const int* from) {
   __m256i low_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
   __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
   __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
   return _mm512_castps_si512(pairs);
@@ -929,6 +1066,17 @@ EIGEN_STRONG_INLINE Packet8d ploadquad<Packet8d>(const double* from) {
   return _mm512_insertf64x4(tmp, lane1, 1);
 }
 
+// Loads 2 int64_t from memory and returns the packet
+// {a0, a0, a0, a0, a1, a1, a1, a1}
+template <>
+EIGEN_STRONG_INLINE Packet8l ploadquad<Packet8l>(const int64_t* from) {
+  __m256i lane0 = _mm256_set1_epi64x(*from);
+  __m256i lane1 = _mm256_set1_epi64x(*(from + 1));
+  __m512i tmp = _mm512_undefined_epi32();
+  tmp = _mm512_inserti64x4(tmp, lane0, 0);
+  return _mm512_inserti64x4(tmp, lane1, 1);
+}
+
 // Loads 4 integers from memory and returns the packet
 // {a0, a0, a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
 template <>
@@ -948,7 +1096,11 @@ EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet8d& from) {
 }
 template <>
 EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet16i& from) {
-  EIGEN_DEBUG_ALIGNED_STORE _mm512_storeu_si512(reinterpret_cast<__m512i*>(to), from);
+  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_epi32(to, from);
+}
+template <>
+EIGEN_STRONG_INLINE void pstore<int64_t>(int64_t* to, const Packet8l& from) {
+  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_epi64(to, from);
 }
 
 template <>
@@ -961,7 +1113,11 @@ EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from) {
 }
 template <>
 EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet16i& from) {
-  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512(reinterpret_cast<__m512i*>(to), from);
+  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_epi32(to, from);
+}
+template <>
+EIGEN_STRONG_INLINE void pstoreu<int64_t>(int64_t* to, const Packet8l& from) {
+  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_epi64(to, from);
 }
 template <>
 EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from, uint16_t umask) {
@@ -1015,6 +1171,14 @@ EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const double* from,
   return _mm512_i32gather_pd(indices, from, 8);
 }
 template <>
+EIGEN_DEVICE_FUNC inline Packet8l pgather<int64_t, Packet8l>(const int64_t* from, Index stride) {
+  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
+  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
+
+  return _mm512_i32gather_epi64(indices, from, 8);
+}
+template <>
 EIGEN_DEVICE_FUNC inline Packet16i pgather<int, Packet16i>(const int* from, Index stride) {
   Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
   Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
@@ -1043,7 +1207,6 @@ EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to, const Packe
   __mmask8 mask = static_cast<__mmask8>(umask);
   _mm512_mask_i32scatter_pd(to, mask, indices, from, 8);
 }
-
 template <>
 EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to, const Packet16f& from, Index stride) {
   Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
@@ -1059,6 +1222,13 @@ EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to, const Packe
   _mm512_i32scatter_pd(to, indices, from, 8);
 }
 template <>
+EIGEN_DEVICE_FUNC inline void pscatter<int64_t, Packet8l>(int64_t* to, const Packet8l& from, Index stride) {
+  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
+  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
+  _mm512_i32scatter_epi64(to, indices, from, 8);
+}
+template <>
 EIGEN_DEVICE_FUNC inline void pscatter<int, Packet16i>(int* to, const Packet16i& from, Index stride) {
   Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
   Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
@@ -1081,6 +1251,11 @@ EIGEN_STRONG_INLINE void pstore1<Packet16i>(int* to, const int& a) {
   Packet16i pa = pset1<Packet16i>(a);
   pstore(to, pa);
 }
+template <>
+EIGEN_STRONG_INLINE void pstore1<Packet8l>(int64_t* to, const int64_t& a) {
+  Packet8l pa = pset1<Packet8l>(a);
+  pstore(to, pa);
+}
 
 template <>
 EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) {
@@ -1097,15 +1272,20 @@ EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) {
 
 template <>
 EIGEN_STRONG_INLINE float pfirst<Packet16f>(const Packet16f& a) {
-  return _mm_cvtss_f32(_mm512_extractf32x4_ps(a, 0));
+  return _mm512_cvtss_f32(a);
 }
 template <>
 EIGEN_STRONG_INLINE double pfirst<Packet8d>(const Packet8d& a) {
-  return _mm_cvtsd_f64(_mm256_extractf128_pd(_mm512_extractf64x4_pd(a, 0), 0));
+  return _mm512_cvtsd_f64(a);
+}
+template <>
+EIGEN_STRONG_INLINE int64_t pfirst<Packet8l>(const Packet8l& a) {
+  int64_t x = _mm_extract_epi64_0(_mm512_extracti32x4_epi32(a, 0));
+  return x;
 }
 template <>
 EIGEN_STRONG_INLINE int pfirst<Packet16i>(const Packet16i& a) {
-  return _mm_extract_epi32(_mm512_extracti32x4_epi32(a, 0), 0);
+  return _mm512_cvtsi512_si32(a);
 }
 
 template <>
@@ -1123,6 +1303,11 @@ EIGEN_STRONG_INLINE Packet16i preverse(const Packet16i& a) {
   return _mm512_permutexvar_epi32(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a);
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet8l preverse(const Packet8l& a) {
+  return _mm512_permutexvar_epi64(_mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7), a);
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f pabs(const Packet16f& a) {
   // _mm512_abs_ps intrinsic not found, so hack around it
@@ -1137,6 +1322,10 @@ template <>
 EIGEN_STRONG_INLINE Packet16i pabs(const Packet16i& a) {
   return _mm512_abs_epi32(a);
 }
+template <>
+EIGEN_STRONG_INLINE Packet8l pabs(const Packet8l& a) {
+  return _mm512_abs_epi64(a);
+}
 
 template <>
 EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) {
@@ -1268,9 +1457,7 @@ EIGEN_STRONG_INLINE float predux<Packet16f>(const Packet16f& a) {
   __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
   __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
   __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3));
-  sum = _mm_hadd_ps(sum, sum);
-  sum = _mm_hadd_ps(sum, _mm_permute_ps(sum, 1));
-  return _mm_cvtss_f32(sum);
+  return predux(sum);
 #endif
 }
 template <>
@@ -1278,26 +1465,17 @@ EIGEN_STRONG_INLINE double predux<Packet8d>(const Packet8d& a) {
   __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
   __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
   __m256d sum = _mm256_add_pd(lane0, lane1);
-  __m256d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1));
-  return _mm_cvtsd_f64(_mm256_castpd256_pd128(_mm256_hadd_pd(tmp0, tmp0)));
+  return predux(sum);
 }
+
+template <>
+EIGEN_STRONG_INLINE int64_t predux<Packet8l>(const Packet8l& a) {
+  return _mm512_reduce_add_epi64(a);
+}
+
 template <>
 EIGEN_STRONG_INLINE int predux<Packet16i>(const Packet16i& a) {
-#ifdef EIGEN_VECTORIZE_AVX512DQ
-  __m256i lane0 = _mm512_extracti32x8_epi32(a, 0);
-  __m256i lane1 = _mm512_extracti32x8_epi32(a, 1);
-  Packet8i x = _mm256_add_epi32(lane0, lane1);
-  return predux(x);
-#else
-  __m128i lane0 = _mm512_extracti32x4_epi32(a, 0);
-  __m128i lane1 = _mm512_extracti32x4_epi32(a, 1);
-  __m128i lane2 = _mm512_extracti32x4_epi32(a, 2);
-  __m128i lane3 = _mm512_extracti32x4_epi32(a, 3);
-  __m128i sum = _mm_add_epi32(_mm_add_epi32(lane0, lane1), _mm_add_epi32(lane2, lane3));
-  sum = _mm_hadd_epi32(sum, sum);
-  sum = _mm_hadd_epi32(sum, _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(sum), 1)));
-  return _mm_cvtsi128_si32(sum);
-#endif
+  return _mm512_reduce_add_epi32(a);
 }
 
 template <>
@@ -1339,6 +1517,13 @@ EIGEN_STRONG_INLINE Packet8i predux_half_dowto4<Packet16i>(const Packet16i& a) {
 #endif
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet4l predux_half_dowto4<Packet8l>(const Packet8l& a) {
+  __m256i lane0 = _mm512_extracti64x4_epi64(a, 0);
+  __m256i lane1 = _mm512_extracti64x4_epi64(a, 1);
+  return _mm256_add_epi64(lane0, lane1);
+}
+
 template <>
 EIGEN_STRONG_INLINE float predux_mul<Packet16f>(const Packet16f& a) {
 // #ifdef EIGEN_VECTORIZE_AVX512DQ
@@ -1367,6 +1552,14 @@ EIGEN_STRONG_INLINE double predux_mul<Packet8d>(const Packet8d& a) {
   res = pmul(res, _mm256_permute2f128_pd(res, res, 1));
   return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1)));
 }
+template <>
+EIGEN_STRONG_INLINE int predux_mul<Packet16i>(const Packet16i& a) {
+  return _mm512_reduce_mul_epi32(a);
+}
+template <>
+EIGEN_STRONG_INLINE int64_t predux_mul<Packet8l>(const Packet8l& a) {
+  return _mm512_reduce_mul_epi64(a);
+}
 
 template <>
 EIGEN_STRONG_INLINE float predux_min<Packet16f>(const Packet16f& a) {
@@ -1386,6 +1579,14 @@ EIGEN_STRONG_INLINE double predux_min<Packet8d>(const Packet8d& a) {
   res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1));
   return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1)));
 }
+template <>
+EIGEN_STRONG_INLINE int predux_min<Packet16i>(const Packet16i& a) {
+  return _mm512_reduce_min_epi32(a);
+}
+template <>
+EIGEN_STRONG_INLINE int64_t predux_min<Packet8l>(const Packet8l& a) {
+  return _mm512_reduce_min_epi64(a);
+}
 
 template <>
 EIGEN_STRONG_INLINE float predux_max<Packet16f>(const Packet16f& a) {
@@ -1406,6 +1607,14 @@ EIGEN_STRONG_INLINE double predux_max<Packet8d>(const Packet8d& a) {
   res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1));
   return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1)));
 }
+template <>
+EIGEN_STRONG_INLINE int predux_max<Packet16i>(const Packet16i& a) {
+  return _mm512_reduce_max_epi32(a);
+}
+template <>
+EIGEN_STRONG_INLINE int64_t predux_max<Packet8l>(const Packet8l& a) {
+  return _mm512_reduce_max_epi64(a);
+}
 
 template <>
 EIGEN_STRONG_INLINE bool predux_any(const Packet16f& x) {
@@ -1617,6 +1826,10 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
   OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \
   OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);
 
+#define PACK_OUTPUT_L(OUTPUT, INPUT, INDEX, STRIDE)                         \
+  OUTPUT[INDEX] = _mm512_inserti64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \
+  OUTPUT[INDEX] = _mm512_inserti64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);
+
 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
   __m512d T0 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0);
   __m512d T1 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0xff);
@@ -1695,6 +1908,88 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
   kernel.packet[7] = T7;
 }
 
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8l, 4>& kernel) {
+  __m512i T0 = _mm512_castpd_si512(
+      _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[0]), _mm512_castsi512_pd(kernel.packet[1]), 0));
+  __m512i T1 = _mm512_castpd_si512(
+      _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[0]), _mm512_castsi512_pd(kernel.packet[1]), 0xff));
+  __m512i T2 = _mm512_castpd_si512(
+      _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[2]), _mm512_castsi512_pd(kernel.packet[3]), 0));
+  __m512i T3 = _mm512_castpd_si512(
+      _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[2]), _mm512_castsi512_pd(kernel.packet[3]), 0xff));
+
+  PacketBlock<Packet4l, 8> tmp;
+
+  tmp.packet[0] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 0), _mm512_extracti64x4_epi64(T2, 0), 0x20);
+  tmp.packet[1] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 0), _mm512_extracti64x4_epi64(T3, 0), 0x20);
+  tmp.packet[2] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 0), _mm512_extracti64x4_epi64(T2, 0), 0x31);
+  tmp.packet[3] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 0), _mm512_extracti64x4_epi64(T3, 0), 0x31);
+
+  tmp.packet[4] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 1), _mm512_extracti64x4_epi64(T2, 1), 0x20);
+  tmp.packet[5] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 1), _mm512_extracti64x4_epi64(T3, 1), 0x20);
+  tmp.packet[6] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 1), _mm512_extracti64x4_epi64(T2, 1), 0x31);
+  tmp.packet[7] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 1), _mm512_extracti64x4_epi64(T3, 1), 0x31);
+
+  PACK_OUTPUT_L(kernel.packet, tmp.packet, 0, 1);
+  PACK_OUTPUT_L(kernel.packet, tmp.packet, 1, 1);
+  PACK_OUTPUT_L(kernel.packet, tmp.packet, 2, 1);
+  PACK_OUTPUT_L(kernel.packet, tmp.packet, 3, 1);
+}
+
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8l, 8>& kernel) {
+  __m512i T0 = _mm512_unpacklo_epi64(kernel.packet[0], kernel.packet[1]);
+  __m512i T1 = _mm512_unpackhi_epi64(kernel.packet[0], kernel.packet[1]);
+  __m512i T2 = _mm512_unpacklo_epi64(kernel.packet[2], kernel.packet[3]);
+  __m512i T3 = _mm512_unpackhi_epi64(kernel.packet[2], kernel.packet[3]);
+  __m512i T4 = _mm512_unpacklo_epi64(kernel.packet[4], kernel.packet[5]);
+  __m512i T5 = _mm512_unpackhi_epi64(kernel.packet[4], kernel.packet[5]);
+  __m512i T6 = _mm512_unpacklo_epi64(kernel.packet[6], kernel.packet[7]);
+  __m512i T7 = _mm512_unpackhi_epi64(kernel.packet[6], kernel.packet[7]);
+
+  kernel.packet[0] = _mm512_permutex_epi64(T2, 0x4E);
+  kernel.packet[0] = _mm512_mask_blend_epi64(0xCC, T0, kernel.packet[0]);
+  kernel.packet[2] = _mm512_permutex_epi64(T0, 0x4E);
+  kernel.packet[2] = _mm512_mask_blend_epi64(0xCC, kernel.packet[2], T2);
+  kernel.packet[1] = _mm512_permutex_epi64(T3, 0x4E);
+  kernel.packet[1] = _mm512_mask_blend_epi64(0xCC, T1, kernel.packet[1]);
+  kernel.packet[3] = _mm512_permutex_epi64(T1, 0x4E);
+  kernel.packet[3] = _mm512_mask_blend_epi64(0xCC, kernel.packet[3], T3);
+  kernel.packet[4] = _mm512_permutex_epi64(T6, 0x4E);
+  kernel.packet[4] = _mm512_mask_blend_epi64(0xCC, T4, kernel.packet[4]);
+  kernel.packet[6] = _mm512_permutex_epi64(T4, 0x4E);
+  kernel.packet[6] = _mm512_mask_blend_epi64(0xCC, kernel.packet[6], T6);
+  kernel.packet[5] = _mm512_permutex_epi64(T7, 0x4E);
+  kernel.packet[5] = _mm512_mask_blend_epi64(0xCC, T5, kernel.packet[5]);
+  kernel.packet[7] = _mm512_permutex_epi64(T5, 0x4E);
+  kernel.packet[7] = _mm512_mask_blend_epi64(0xCC, kernel.packet[7], T7);
+
+  T0 = _mm512_shuffle_i64x2(kernel.packet[4], kernel.packet[4], 0x4E);
+  T0 = _mm512_mask_blend_epi64(0xF0, kernel.packet[0], T0);
+  T4 = _mm512_shuffle_i64x2(kernel.packet[0], kernel.packet[0], 0x4E);
+  T4 = _mm512_mask_blend_epi64(0xF0, T4, kernel.packet[4]);
+  T1 = _mm512_shuffle_i64x2(kernel.packet[5], kernel.packet[5], 0x4E);
+  T1 = _mm512_mask_blend_epi64(0xF0, kernel.packet[1], T1);
+  T5 = _mm512_shuffle_i64x2(kernel.packet[1], kernel.packet[1], 0x4E);
+  T5 = _mm512_mask_blend_epi64(0xF0, T5, kernel.packet[5]);
+  T2 = _mm512_shuffle_i64x2(kernel.packet[6], kernel.packet[6], 0x4E);
+  T2 = _mm512_mask_blend_epi64(0xF0, kernel.packet[2], T2);
+  T6 = _mm512_shuffle_i64x2(kernel.packet[2], kernel.packet[2], 0x4E);
+  T6 = _mm512_mask_blend_epi64(0xF0, T6, kernel.packet[6]);
+  T3 = _mm512_shuffle_i64x2(kernel.packet[7], kernel.packet[7], 0x4E);
+  T3 = _mm512_mask_blend_epi64(0xF0, kernel.packet[3], T3);
+  T7 = _mm512_shuffle_i64x2(kernel.packet[3], kernel.packet[3], 0x4E);
+  T7 = _mm512_mask_blend_epi64(0xF0, T7, kernel.packet[7]);
+
+  kernel.packet[0] = T0;
+  kernel.packet[1] = T1;
+  kernel.packet[2] = T2;
+  kernel.packet[3] = T3;
+  kernel.packet[4] = T4;
+  kernel.packet[5] = T5;
+  kernel.packet[6] = T6;
+  kernel.packet[7] = T7;
+}
+
 #define PACK_OUTPUT_I32(OUTPUT, INPUT, INDEX, STRIDE) \
   EIGEN_INSERT_8i_INTO_16i(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]);
 
diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h
index ccdb563de..b16e9f6c8 100644
--- a/Eigen/src/Core/arch/AVX512/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h
@@ -37,6 +37,11 @@ struct type_casting_traits<float, double> : vectorized_type_casting_traits
 struct type_casting_traits<double, float> : vectorized_type_casting_traits<double, float> {};
 
+template <>
+struct type_casting_traits<double, int64_t> : vectorized_type_casting_traits<double, int64_t> {};
+template <>
+struct type_casting_traits<int64_t, double> : vectorized_type_casting_traits<int64_t, double> {};
+
 #ifndef EIGEN_VECTORIZE_AVX512FP16
 template <>
 struct type_casting_traits<half, float> : vectorized_type_casting_traits<half, float> {};
@@ -75,6 +80,19 @@ EIGEN_STRONG_INLINE Packet8d pcast<Packet8f, Packet8d>(const Packet8f& a) {
   return _mm512_cvtps_pd(a);
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet8l pcast<Packet8d, Packet8l>(const Packet8d& a) {
+#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
+  return _mm512_cvttpd_epi64(a);
+#else
+  EIGEN_ALIGN64 double aux[8];
+  pstore(aux, a);
+  return _mm512_set_epi64(static_cast<int64_t>(aux[7]), static_cast<int64_t>(aux[6]), static_cast<int64_t>(aux[5]),
+                          static_cast<int64_t>(aux[4]), static_cast<int64_t>(aux[3]), static_cast<int64_t>(aux[2]),
+                          static_cast<int64_t>(aux[1]), static_cast<int64_t>(aux[0]));
+#endif
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
   return _mm512_cvtepi32_ps(a);
@@ -90,6 +108,19 @@ EIGEN_STRONG_INLINE Packet8d pcast<Packet8i, Packet8d>(const Packet8i& a) {
   return _mm512_cvtepi32_pd(a);
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet8d pcast<Packet8l, Packet8d>(const Packet8l& a) {
+#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
+  return _mm512_cvtepi64_pd(a);
+#else
+  EIGEN_ALIGN64 int64_t aux[8];
+  pstore(aux, a);
+  return _mm512_set_pd(static_cast<double>(aux[7]), static_cast<double>(aux[6]), static_cast<double>(aux[5]),
+                       static_cast<double>(aux[4]), static_cast<double>(aux[3]), static_cast<double>(aux[2]),
+                       static_cast<double>(aux[1]), static_cast<double>(aux[0]));
+#endif
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f pcast<Packet8d, Packet16f>(const Packet8d& a, const Packet8d& b) {
   return cat256(_mm512_cvtpd_ps(a), _mm512_cvtpd_ps(b));
@@ -124,6 +155,16 @@ EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f&
   return _mm512_castps_pd(a);
 }
 
+template <>
+EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8l>(const Packet8l& a) {
+  return _mm512_castsi512_pd(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8l preinterpret<Packet8l, Packet8d>(const Packet8d& a) {
+  return _mm512_castpd_si512(a);
+}
+
 template <>
 EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
   return _mm512_castpd_ps(a);
 }
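
Editor's note (not part of the patch): for readers unfamiliar with Eigen's
internal packet layer, here is a minimal sketch of how the Packet8l
primitives added above compose. It assumes a build with AVX-512 enabled
(e.g. -mavx512f); Eigen::internal is not a stable public API, and the
values are illustrative only.

    #include <Eigen/Core>
    #include <cstdint>
    #include <iostream>

    int main() {
      using namespace Eigen::internal;
      // plset<Packet8l> broadcasts 1 and adds the ramp {0,...,7} -> {1,...,8}.
      Packet8l p = plset<Packet8l>(int64_t(1));
      // predux<Packet8l> maps to _mm512_reduce_add_epi64: 1+2+...+8 = 36.
      int64_t sum = predux(p);
      // pfirst<Packet8l> extracts lane 0 of the packet.
      int64_t first = pfirst(p);
      std::cout << sum << " " << first << "\n";  // prints "36 1"
      return 0;
    }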