Add missing explicit reinterprets

This commit is contained in:
b-shi 2022-03-23 21:10:26 +00:00 committed by Antonio Sánchez
parent cd3c81c3bc
commit 0611f7fff0
2 changed files with 32 additions and 24 deletions

View File

@ -32,6 +32,26 @@ template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(cons
return _mm512_castsi512_ps(a);
}
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
return _mm512_castps_pd(a);
}
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
return _mm512_castpd_ps(a);
}
template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
return _mm512_castps512_ps256(a);
}
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
return a;
}
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8d>(const Packet8d& a) {
return a;
}
template <>
struct type_casting_traits<half, float> {
enum {

View File

@ -65,21 +65,6 @@ EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
return 0;
}
template<typename T1, typename T2>
EIGEN_ALWAYS_INLINE T2 castPacket(T1 &a) {
return reinterpret_cast<T2>(a);
}
template<>
EIGEN_ALWAYS_INLINE vecHalfFloat castPacket(vecFullFloat &a) {
return _mm512_castps512_ps256(a);
}
template<>
EIGEN_ALWAYS_INLINE vecFullDouble castPacket(vecFullDouble &a) {
return a;
}
/***
* Unrolls for tranposed C stores
*/
@ -118,7 +103,7 @@ public:
pstoreu<Scalar>(
C_arr + LDC*startN,
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
}
@ -126,27 +111,30 @@ public:
pstoreu<Scalar>(
C_arr + LDC*startN,
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
}
}
else {
else { // This block is only needed for fp32 case
// Reinterpret as __m512 for _mm512_shuffle_f32x4
vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)]);
// Swap lower and upper half of avx register.
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] =
_mm512_shuffle_f32x4(
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)],
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)], 0b01001110);
preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
EIGEN_IF_CONSTEXPR(remM) {
pstoreu<Scalar>(
C_arr + LDC*startN,
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN,
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
}
else {
pstoreu<Scalar>(
C_arr + LDC*startN,
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
}
}
aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
@ -234,7 +222,7 @@ public:
* used as temps for transposing.
*
* Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we cast packets to packets half the size (zmm -> ymm).
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
*/
template <typename Scalar>
class transB {