Add missing explicit reinterprets

b-shi 2022-03-23 21:10:26 +00:00 committed by Antonio Sánchez
parent cd3c81c3bc
commit 0611f7fff0
2 changed files with 32 additions and 24 deletions
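
In short: the AVX-512 TRSM kernel previously reinterpreted packets through a local castPacket helper whose generic fallback used reinterpret_cast between SIMD value types; this commit removes that helper, routes all reinterpretation through Eigen's preinterpret, and adds the AVX-512 specializations preinterpret was still missing. The sketch below is not part of the commit ("reinterpret_packet" is a made-up name); it only illustrates the pattern being relied on: a generic entry point with explicit specializations that map each packet pair to the matching zero-cost AVX-512 cast intrinsic.

// Hedged sketch, assumes AVX-512F (compile with -mavx512f).
#include <immintrin.h>

template <typename Target, typename Src>
Target reinterpret_packet(const Src& a);  // no generic definition: unsupported pairs fail to link

template <>
inline __m512d reinterpret_packet<__m512d, __m512>(const __m512& a) {
  return _mm512_castps_pd(a);             // same 512 bits, viewed as 8 doubles
}
template <>
inline __m512 reinterpret_packet<__m512, __m512d>(const __m512d& a) {
  return _mm512_castpd_ps(a);             // and back to 16 floats
}
template <>
inline __m256 reinterpret_packet<__m256, __m512>(const __m512& a) {
  return _mm512_castps512_ps256(a);       // keep only the lower 256 bits (zmm -> ymm)
}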


@@ -32,6 +32,26 @@ template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(cons
   return _mm512_castsi512_ps(a);
 }
+template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
+  return _mm512_castps_pd(a);
+}
+template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
+  return _mm512_castpd_ps(a);
+}
+template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
+  return _mm512_castps512_ps256(a);
+}
+template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
+  return a;
+}
+template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8d>(const Packet8d& a) {
+  return a;
+}
 template <>
 struct type_casting_traits<half, float> {
   enum {
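
For context, a minimal usage sketch (not from the commit) of the specializations added above, assuming Eigen is built with AVX-512 enabled (e.g. -mavx512f) so that Packet16f and Packet8d map to __m512 and __m512d. preinterpret only relabels the register type; no values are converted.

#include <Eigen/Core>

int main() {
  using namespace Eigen::internal;
  Packet16f f = pset1<Packet16f>(1.0f);
  Packet8d  d = preinterpret<Packet8d>(f);   // view the same 512 bits as 8 doubles (no conversion)
  Packet16f r = preinterpret<Packet16f>(d);  // round trip recovers the original packet bit-for-bit
  (void)r;
  return 0;
}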


@@ -65,21 +65,6 @@ EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
   return 0;
 }
-template<typename T1, typename T2>
-EIGEN_ALWAYS_INLINE T2 castPacket(T1 &a) {
-  return reinterpret_cast<T2>(a);
-}
-template<>
-EIGEN_ALWAYS_INLINE vecHalfFloat castPacket(vecFullFloat &a) {
-  return _mm512_castps512_ps256(a);
-}
-template<>
-EIGEN_ALWAYS_INLINE vecFullDouble castPacket(vecFullDouble &a) {
-  return a;
-}
 /***
  * Unrolls for tranposed C stores
  */
@@ -118,7 +103,7 @@ public:
       pstoreu<Scalar>(
         C_arr + LDC*startN,
         padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
-             castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
+             preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
              remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
         remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
     }
@@ -126,27 +111,30 @@ public:
       pstoreu<Scalar>(
         C_arr + LDC*startN,
         padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
-             castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
+             preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
       }
     }
-    else {
+    else { // This block is only needed for fp32 case
+      // Reinterpret as __m512 for _mm512_shuffle_f32x4
+      vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
+        zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)]);
+      // Swap lower and upper half of avx register.
       zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] =
-        _mm512_shuffle_f32x4(
-          zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)],
-          zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)], 0b01001110);
+        preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
       EIGEN_IF_CONSTEXPR(remM) {
         pstoreu<Scalar>(
           C_arr + LDC*startN,
           padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN,
                                remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
-               castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
+               preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
               remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
       }
       else {
         pstoreu<Scalar>(
           C_arr + LDC*startN,
           padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
-               castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
+               preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
       }
     }
     aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
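
A note on the shuffle in the hunk above (not part of the commit): _mm512_shuffle_f32x4 only exists for __m512, which is why the double packet is first reinterpreted to vecFullFloat. With immediate 0b01001110 it picks 128-bit lanes 2,3 from the first operand and 0,1 from the second, so with identical operands it swaps the lower and upper 256-bit halves of the register. A standalone sketch, assuming AVX-512F (compile with -mavx512f):

#include <immintrin.h>
#include <cstdio>

int main() {
  alignas(64) float in[16], out[16];
  for (int i = 0; i < 16; ++i) in[i] = (float)i;
  __m512 a = _mm512_load_ps(in);
  // Lanes {2,3,0,1}: the two 256-bit halves trade places.
  __m512 s = _mm512_shuffle_f32x4(a, a, 0b01001110);
  _mm512_store_ps(out, s);
  for (int i = 0; i < 16; ++i) printf("%g ", out[i]);  // prints 8..15 then 0..7
  printf("\n");
  return 0;
}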
@@ -234,7 +222,7 @@ public:
  * used as temps for transposing.
  *
  * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
- * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we cast packets to packets half the size (zmm -> ymm).
+ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
  */
 template <typename Scalar>
 class transB {
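
Regarding the zmm -> ymm remark in the comment above: for fp32 the half-size packet is a 256-bit register, and the reinterpret comes down to _mm512_castps512_ps256, which exposes the lower half without emitting any instruction. A minimal sketch (not from the commit; assumes AVX-512F):

#include <immintrin.h>
#include <cstdio>

int main() {
  alignas(64) float in[16];
  for (int i = 0; i < 16; ++i) in[i] = (float)i;
  __m512 zmm = _mm512_load_ps(in);
  __m256 ymm = _mm512_castps512_ps256(zmm);  // lower 8 floats, no data movement
  alignas(32) float out[8];
  _mm256_store_ps(out, ymm);
  for (int i = 0; i < 8; ++i) printf("%g ", out[i]);  // prints 0..7
  printf("\n");
  return 0;
}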