mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-08 17:59:00 +08:00
Add missing explicit reinterprets
This commit is contained in:
parent
cd3c81c3bc
commit
0611f7fff0
@ -32,6 +32,26 @@ template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(cons
|
||||
return _mm512_castsi512_ps(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
|
||||
return _mm512_castps_pd(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
|
||||
return _mm512_castpd_ps(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
|
||||
return _mm512_castps512_ps256(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8d>(const Packet8d& a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<half, float> {
|
||||
enum {
|
||||
|
@ -65,21 +65,6 @@ EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<typename T1, typename T2>
|
||||
EIGEN_ALWAYS_INLINE T2 castPacket(T1 &a) {
|
||||
return reinterpret_cast<T2>(a);
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_ALWAYS_INLINE vecHalfFloat castPacket(vecFullFloat &a) {
|
||||
return _mm512_castps512_ps256(a);
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_ALWAYS_INLINE vecFullDouble castPacket(vecFullDouble &a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
/***
|
||||
* Unrolls for tranposed C stores
|
||||
*/
|
||||
@ -118,7 +103,7 @@ public:
|
||||
pstoreu<Scalar>(
|
||||
C_arr + LDC*startN,
|
||||
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
||||
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
|
||||
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
|
||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
||||
}
|
||||
@ -126,27 +111,30 @@ public:
|
||||
pstoreu<Scalar>(
|
||||
C_arr + LDC*startN,
|
||||
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
|
||||
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
|
||||
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
|
||||
}
|
||||
}
|
||||
else {
|
||||
else { // This block is only needed for fp32 case
|
||||
// Reinterpret as __m512 for _mm512_shuffle_f32x4
|
||||
vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
|
||||
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)]);
|
||||
// Swap lower and upper half of avx register.
|
||||
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] =
|
||||
_mm512_shuffle_f32x4(
|
||||
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)],
|
||||
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)], 0b01001110);
|
||||
preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
|
||||
|
||||
EIGEN_IF_CONSTEXPR(remM) {
|
||||
pstoreu<Scalar>(
|
||||
C_arr + LDC*startN,
|
||||
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN,
|
||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
||||
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
|
||||
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
|
||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
||||
}
|
||||
else {
|
||||
pstoreu<Scalar>(
|
||||
C_arr + LDC*startN,
|
||||
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
|
||||
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
|
||||
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
|
||||
}
|
||||
}
|
||||
aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
|
||||
@ -234,7 +222,7 @@ public:
|
||||
* used as temps for transposing.
|
||||
*
|
||||
* Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
|
||||
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we cast packets to packets half the size (zmm -> ymm).
|
||||
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
|
||||
*/
|
||||
template <typename Scalar>
|
||||
class transB {
|
||||
|
Loading…
x
Reference in New Issue
Block a user