From 0611f7fff0eea6d61c94a77583578409d6f44251 Mon Sep 17 00:00:00 2001 From: b-shi Date: Wed, 23 Mar 2022 21:10:26 +0000 Subject: [PATCH] Add missing explicit reinterprets --- Eigen/src/Core/arch/AVX512/TypeCasting.h | 20 ++++++++++++ Eigen/src/Core/arch/AVX512/unrolls_impl.hpp | 36 +++++++-------------- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h index 2f299e221..8baced1d6 100644 --- a/Eigen/src/Core/arch/AVX512/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -32,6 +32,26 @@ template<> EIGEN_STRONG_INLINE Packet16f preinterpret(cons return _mm512_castsi512_ps(a); } +template<> EIGEN_STRONG_INLINE Packet8d preinterpret(const Packet16f& a) { + return _mm512_castps_pd(a); +} + +template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet8d& a) { + return _mm512_castpd_ps(a); +} + +template<> EIGEN_STRONG_INLINE Packet8f preinterpret(const Packet16f& a) { + return _mm512_castps512_ps256(a); +} + +template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet16f& a) { + return a; +} + +template<> EIGEN_STRONG_INLINE Packet8d preinterpret(const Packet8d& a) { + return a; +} + template <> struct type_casting_traits { enum { diff --git a/Eigen/src/Core/arch/AVX512/unrolls_impl.hpp b/Eigen/src/Core/arch/AVX512/unrolls_impl.hpp index 81bcf0980..22cb1c93d 100644 --- a/Eigen/src/Core/arch/AVX512/unrolls_impl.hpp +++ b/Eigen/src/Core/arch/AVX512/unrolls_impl.hpp @@ -65,21 +65,6 @@ EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { return 0; } -template -EIGEN_ALWAYS_INLINE T2 castPacket(T1 &a) { - return reinterpret_cast(a); -} - -template<> -EIGEN_ALWAYS_INLINE vecHalfFloat castPacket(vecFullFloat &a) { - return _mm512_castps512_ps256(a); -} - -template<> -EIGEN_ALWAYS_INLINE vecFullDouble castPacket(vecFullDouble &a) { - return a; -} - /*** * Unrolls for tranposed C stores */ @@ -118,7 +103,7 @@ public: pstoreu( C_arr + LDC*startN, padd(ploadu((const Scalar*)C_arr + LDC*startN, remMask(remM_)), - castPacket(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]), + preinterpret(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]), remMask(remM_)), remMask(remM_)); } @@ -126,27 +111,30 @@ public: pstoreu( C_arr + LDC*startN, padd(ploadu((const Scalar*)C_arr + LDC*startN), - castPacket(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]))); + preinterpret(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]))); } } - else { + else { // This block is only needed for fp32 case + // Reinterpret as __m512 for _mm512_shuffle_f32x4 + vecFullFloat zmm2vecFullFloat = preinterpret( + zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)]); + // Swap lower and upper half of avx register. zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] = - _mm512_shuffle_f32x4( - zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)], - zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)], 0b01001110); + preinterpret(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110)); + EIGEN_IF_CONSTEXPR(remM) { pstoreu( C_arr + LDC*startN, padd(ploadu((const Scalar*)C_arr + LDC*startN, remMask(remM_)), - castPacket(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])), + preinterpret(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])), remMask(remM_)); } else { pstoreu( C_arr + LDC*startN, padd(ploadu((const Scalar*)C_arr + LDC*startN), - castPacket(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)]))); + preinterpret(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)]))); } } aux_storeC(C_arr, LDC, zmm, remM_); @@ -234,7 +222,7 @@ public: * used as temps for transposing. * * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks - * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we cast packets to packets half the size (zmm -> ymm). + * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm). */ template class transB {