From a0fc640c185b8d021f31b386b4eed0dd06250cdf Mon Sep 17 00:00:00 2001 From: Ilya Tokar Date: Fri, 21 Jan 2022 19:55:23 +0000 Subject: [PATCH] Add support for packets of int64 on x86 --- Eigen/src/Core/arch/AVX/PacketMath.h | 221 +++++++++++++++++++++++- Eigen/src/Core/arch/AVX512/PacketMath.h | 20 +-- 2 files changed, 229 insertions(+), 12 deletions(-) diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index ef73d35c5..6b20da6ae 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -31,16 +31,25 @@ namespace internal { #endif typedef __m256 Packet8f; -typedef __m256i Packet8i; +typedef eigen_packet_wrapper<__m256i, 0> Packet8i; typedef __m256d Packet4d; typedef eigen_packet_wrapper<__m128i, 2> Packet8h; typedef eigen_packet_wrapper<__m128i, 3> Packet8bf; +#ifdef EIGEN_VECTORIZE_AVX2 +// Start from 3 to be compatible with AVX512 +typedef eigen_packet_wrapper<__m256i, 3> Packet4l; +#endif + template<> struct is_arithmetic<__m256> { enum { value = true }; }; template<> struct is_arithmetic<__m256i> { enum { value = true }; }; template<> struct is_arithmetic<__m256d> { enum { value = true }; }; +template<> struct is_arithmetic { enum { value = true }; }; template<> struct is_arithmetic { enum { value = true }; }; template<> struct is_arithmetic { enum { value = true }; }; +#ifdef EIGEN_VECTORIZE_AVX2 +template<> struct is_arithmetic { enum { value = true }; }; +#endif #define EIGEN_DECLARE_CONST_Packet8f(NAME,X) \ const Packet8f p8f_##NAME = pset1(X) @@ -209,6 +218,25 @@ template<> struct packet_traits : default_packet_traits size=8 }; }; + +#ifdef EIGEN_VECTORIZE_AVX2 +template<> struct packet_traits : default_packet_traits +{ + typedef Packet4l type; + // There is no half-size packet for current Packet4l. + // TODO: support as SSE path. 
+ typedef Packet4l half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=4, + + // requires AVX512 + HasShift = 0, + }; +}; +#endif + #endif template<> struct scalar_div_cost { enum { value = 14 }; }; @@ -231,6 +259,13 @@ template<> struct unpacket_traits { typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; +#ifdef EIGEN_VECTORIZE_AVX2 +template<> struct unpacket_traits { + typedef int64_t type; + typedef Packet4l half; + enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; +#endif template<> struct unpacket_traits { typedef bfloat16 type; typedef Packet8bf half; @@ -244,6 +279,181 @@ EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) { _mm256_extractf128_si256(_mm256_castps_si256(rf), 1)); } +#ifdef EIGEN_VECTORIZE_AVX2 +template <> +EIGEN_STRONG_INLINE Packet4l pset1(const int64_t& from) { + return _mm256_set1_epi64x(from); +} +template <> +EIGEN_STRONG_INLINE Packet4l pzero(const Packet4l& /*a*/) { + return _mm256_setzero_si256(); +} +template <> +EIGEN_STRONG_INLINE Packet4l peven_mask(const Packet4l& /*a*/) { + return _mm256_set_epi64x(0ll, -1ll, 0ll, -1ll); +} +template <> +EIGEN_STRONG_INLINE Packet4l pload1(const int64_t* from) { + return _mm256_set1_epi64x(*from); +} +template <> +EIGEN_STRONG_INLINE Packet4l padd(const Packet4l& a, const Packet4l& b) { + return _mm256_add_epi64(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet4l plset(const int64_t& a) { + return padd(pset1(a), Packet4l(_mm256_set_epi64x(3ll, 2ll, 1ll, 0ll))); +} +template <> +EIGEN_STRONG_INLINE Packet4l psub(const Packet4l& a, const Packet4l& b) { + return _mm256_sub_epi64(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet4l pnegate(const Packet4l& a) { + return psub(pzero(a), a); +} +template <> +EIGEN_STRONG_INLINE Packet4l pconj(const Packet4l& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet4l pcmp_le(const 
Packet4l& a, const Packet4l& b) { + return _mm256_xor_si256(_mm256_cmpgt_epi64(a, b), _mm256_set1_epi32(-1)); +} +template <> +EIGEN_STRONG_INLINE Packet4l pcmp_lt(const Packet4l& a, const Packet4l& b) { + return _mm256_cmpgt_epi64(b, a); +} +template <> +EIGEN_STRONG_INLINE Packet4l pcmp_eq(const Packet4l& a, const Packet4l& b) { + return _mm256_cmpeq_epi64(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet4l ptrue(const Packet4l& a) { + return _mm256_cmpeq_epi64(a, a); +} +template <> +EIGEN_STRONG_INLINE Packet4l pand(const Packet4l& a, const Packet4l& b) { + return _mm256_and_si256(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet4l por(const Packet4l& a, const Packet4l& b) { + return _mm256_or_si256(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet4l pxor(const Packet4l& a, const Packet4l& b) { + return _mm256_xor_si256(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet4l pandnot(const Packet4l& a, const Packet4l& b) { + return _mm256_andnot_si256(b, a); +} +template +EIGEN_STRONG_INLINE Packet4l plogical_shift_right(Packet4l a) { + return _mm256_srli_epi64(a, N); +} +template +EIGEN_STRONG_INLINE Packet4l plogical_shift_left(Packet4l a) { + return _mm256_slli_epi64(a, N); +} +template <> +EIGEN_STRONG_INLINE Packet4l pload(const int64_t* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast(from)); +} +template <> +EIGEN_STRONG_INLINE Packet4l ploadu(const int64_t* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast(from)); +} +// Loads 2 int64_ts from memory and returns the packet {a0, a0, a1, a1} +template <> +EIGEN_STRONG_INLINE Packet4l ploaddup(const int64_t* from) { + const Packet4l a = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast(from))); + return _mm256_permutevar8x32_epi32(a, _mm256_setr_epi32(0, 1, 0, 1, 2, 3, 2, 3)); +} +template <> +EIGEN_STRONG_INLINE void pstore(int64_t* to, const Packet4l& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
+} +template <> +EIGEN_STRONG_INLINE void pstoreu(int64_t* to, const Packet4l& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); +} +template <> +EIGEN_DEVICE_FUNC inline Packet4l pgather(const int64_t* from, Index stride) { + return _mm256_set_epi64x(from[3 * stride], from[2 * stride], from[1 * stride], from[0 * stride]); +} +template <> +EIGEN_DEVICE_FUNC inline void pscatter(int64_t* to, const Packet4l& from, Index stride) { + __m128i low = _mm256_extractf128_si256(from, 0); + to[stride * 0] = _mm_extract_epi64(low, 0); + to[stride * 1] = _mm_extract_epi64(low, 1); + + __m128i high = _mm256_extractf128_si256(from, 1); + to[stride * 2] = _mm_extract_epi64(high, 0); + to[stride * 3] = _mm_extract_epi64(high, 1); +} +template <> +EIGEN_STRONG_INLINE void pstore1(int64_t* to, const int64_t& a) { + Packet4l pa = pset1(a); + pstore(to, pa); +} +template <> +EIGEN_STRONG_INLINE int64_t pfirst(const Packet4l& a) { + return _mm_cvtsi128_si64(_mm256_castsi256_si128(a)); +} +template <> +EIGEN_STRONG_INLINE int64_t predux(const Packet4l& a) { + __m128i r = _mm_add_epi64(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + return _mm_extract_epi64(r, 0) + _mm_extract_epi64(r, 1); +} +#define MM256_SHUFFLE_EPI64(A, B, M) _mm256_shuffle_pd(_mm256_castsi256_pd(A), _mm256_castsi256_pd(B), M) +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + __m256d T0 = MM256_SHUFFLE_EPI64(kernel.packet[0], kernel.packet[1], 15); + __m256d T1 = MM256_SHUFFLE_EPI64(kernel.packet[0], kernel.packet[1], 0); + __m256d T2 = MM256_SHUFFLE_EPI64(kernel.packet[2], kernel.packet[3], 15); + __m256d T3 = MM256_SHUFFLE_EPI64(kernel.packet[2], kernel.packet[3], 0); + + kernel.packet[1] = _mm256_castpd_si256(_mm256_permute2f128_pd(T0, T2, 32)); + kernel.packet[3] = _mm256_castpd_si256(_mm256_permute2f128_pd(T0, T2, 49)); + kernel.packet[0] = _mm256_castpd_si256(_mm256_permute2f128_pd(T1, T3, 32)); + kernel.packet[2] = 
_mm256_castpd_si256(_mm256_permute2f128_pd(T1, T3, 49)); +} +template <> +EIGEN_STRONG_INLINE Packet4l pmin(const Packet4l& a, const Packet4l& b) { + __m256i cmp = _mm256_cmpgt_epi64(a, b); + __m256i a_min = _mm256_andnot_si256(cmp, a); + __m256i b_min = _mm256_and_si256(cmp, b); + return Packet4l(_mm256_or_si256(a_min, b_min)); +} +template <> +EIGEN_STRONG_INLINE Packet4l pmax(const Packet4l& a, const Packet4l& b) { + __m256i cmp = _mm256_cmpgt_epi64(a, b); + __m256i a_min = _mm256_and_si256(cmp, a); + __m256i b_min = _mm256_andnot_si256(cmp, b); + return Packet4l(_mm256_or_si256(a_min, b_min)); +} +template <> +EIGEN_STRONG_INLINE Packet4l pabs(const Packet4l& a) { + Packet4l pz = pzero(a); + Packet4l cmp = _mm256_cmpgt_epi64(a, pz); + return psub(cmp, pxor(a, cmp)); +} +template <> +EIGEN_STRONG_INLINE Packet4l pmul(const Packet4l& a, const Packet4l& b) { + // 64-bit mul requires avx512, so do this with 32-bit multiplication + __m256i upper32_a = _mm256_srli_epi64(a, 32); + __m256i upper32_b = _mm256_srli_epi64(b, 32); + + // upper * lower + __m256i mul1 = _mm256_mul_epu32(upper32_a, b); + __m256i mul2 = _mm256_mul_epu32(upper32_b, a); + // lower * lower: _mm256_mul_epu32 reads only the low 32 bits of each 64-bit lane (upper*upper would only affect bits above 63 and is not needed) + __m256i mul3 = _mm256_mul_epu32(a, b); + + __m256i high = _mm256_slli_epi64(_mm256_add_epi64(mul1, mul2), 32); + return _mm256_add_epi64(high, mul3); +} +#endif template<> EIGEN_STRONG_INLINE Packet8f pset1(const float& from) { return _mm256_set1_ps(from); } template<> EIGEN_STRONG_INLINE Packet4d pset1(const double& from) { return _mm256_set1_pd(from); } @@ -278,7 +488,7 @@ template<> EIGEN_STRONG_INLINE Packet8i padd(const Packet8i& a, const template<> EIGEN_STRONG_INLINE Packet8f plset(const float& a) { return padd(pset1(a), _mm256_set_ps(7.0,6.0,5.0,4.0,3.0,2.0,1.0,0.0)); } template<> EIGEN_STRONG_INLINE Packet4d plset(const double& a) { return padd(pset1(a), _mm256_set_pd(3.0,2.0,1.0,0.0)); } -template<> EIGEN_STRONG_INLINE Packet8i plset(const int& a) { return padd(pset1(a), 
_mm256_set_epi32(7,6,5,4,3,2,1,0)); } +template<> EIGEN_STRONG_INLINE Packet8i plset(const int& a) { return padd(pset1(a), (Packet8i)_mm256_set_epi32(7,6,5,4,3,2,1,0)); } template<> EIGEN_STRONG_INLINE Packet8f psub(const Packet8f& a, const Packet8f& b) { return _mm256_sub_ps(a,b); } template<> EIGEN_STRONG_INLINE Packet4d psub(const Packet4d& a, const Packet4d& b) { return _mm256_sub_pd(a,b); } @@ -812,6 +1022,13 @@ template<> EIGEN_STRONG_INLINE Packet8i preverse(const Packet8i& a) return _mm256_castps_si256(preverse(_mm256_castsi256_ps(a))); } +#ifdef EIGEN_VECTORIZE_AVX2 +template<> EIGEN_STRONG_INLINE Packet4l preverse(const Packet4l& a) +{ + return _mm256_castpd_si256(preverse(_mm256_castsi256_pd(a))); +} +#endif + // pabs should be ok template<> EIGEN_STRONG_INLINE Packet8f pabs(const Packet8f& a) { diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 4f857269e..ff0dd29b4 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -1772,7 +1772,7 @@ EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) { } template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) { - return ptrue(Packet8i(a)); + return Packet16h(ptrue(Packet8i(a))); } template <> @@ -1801,16 +1801,16 @@ EIGEN_STRONG_INLINE Packet16h plset(const half& a) { template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) { // in some cases Packet8i is a wrapper around __m256i, so we need to // cast to Packet8i to call the correct overload. 
- return por(Packet8i(a),Packet8i(b)); + return Packet16h(por(Packet8i(a),Packet8i(b))); } template<> EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a,const Packet16h& b) { - return pxor(Packet8i(a),Packet8i(b)); + return Packet16h(pxor(Packet8i(a),Packet8i(b))); } template<> EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a,const Packet16h& b) { - return pand(Packet8i(a),Packet8i(b)); + return Packet16h(pand(Packet8i(a),Packet8i(b))); } template<> EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a,const Packet16h& b) { - return pandnot(Packet8i(a),Packet8i(b)); + return Packet16h(pandnot(Packet8i(a),Packet8i(b))); } template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) { @@ -2265,28 +2265,28 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { template <> EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) { - return ptrue(a); + return Packet16bf(ptrue(Packet8i(a))); } template <> EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) { - return por(a, b); + return Packet16bf(por(Packet8i(a), Packet8i(b))); } template <> EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) { - return pxor(a, b); + return Packet16bf(pxor(Packet8i(a), Packet8i(b))); } template <> EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) { - return pand(a, b); + return Packet16bf(pand((Packet8i)a, (Packet8i)b)); } template <> EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a, const Packet16bf& b) { - return pandnot(a, b); + return Packet16bf(pandnot((Packet8i)a, (Packet8i)b)); } template <>