mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-08 09:49:03 +08:00
Add missing explicit reinterprets
This commit is contained in:
parent
cd3c81c3bc
commit
0611f7fff0
@ -32,6 +32,26 @@ template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(cons
|
|||||||
return _mm512_castsi512_ps(a);
|
return _mm512_castsi512_ps(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// preinterpret specialization: Packet16f -> Packet8d.
// Forwards the argument to the _mm512_castps_pd intrinsic and returns its result.
// NOTE(review): presumably a type-only cast that leaves the register bits untouched - confirm against the intrinsic's documentation.
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
|
||||||
|
return _mm512_castps_pd(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
// preinterpret specialization: Packet8d -> Packet16f.
// Forwards the argument to the _mm512_castpd_ps intrinsic and returns its result.
// NOTE(review): presumably a type-only cast that leaves the register bits untouched - confirm against the intrinsic's documentation.
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
|
||||||
|
return _mm512_castpd_ps(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
// preinterpret specialization: Packet16f -> Packet8f (a packet half the size).
// Forwards the argument to the _mm512_castps512_ps256 intrinsic and returns its result.
// NOTE(review): presumably yields the low 256 bits of the 512-bit register - confirm against the intrinsic's documentation.
template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
|
||||||
|
return _mm512_castps512_ps256(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identity preinterpret specialization for Packet16f: source and target types match,
// so the argument is returned as-is.
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identity preinterpret specialization for Packet8d: source and target types match,
// so the argument is returned as-is.
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8d>(const Packet8d& a) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct type_casting_traits<half, float> {
|
struct type_casting_traits<half, float> {
|
||||||
enum {
|
enum {
|
||||||
|
@ -65,21 +65,6 @@ EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generic packet cast helper: returns reinterpret_cast<T2>(a) for a packet `a` of type T1.
// NOTE(review): this definition appears only on the old side of this diff hunk, and the
// call sites further down switch from castPacket<vec,vecHalf>(...) to preinterpret<vecHalf>(...);
// presumably it is being removed - confirm.
template<typename T1, typename T2>
|
|
||||||
EIGEN_ALWAYS_INLINE T2 castPacket(T1 &a) {
|
|
||||||
return reinterpret_cast<T2>(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
// castPacket specialization: vecFullFloat -> vecHalfFloat.
// Forwards the argument to the _mm512_castps512_ps256 intrinsic and returns its result.
// NOTE(review): appears only on the old side of this diff hunk - presumably removed; confirm.
template<>
|
|
||||||
EIGEN_ALWAYS_INLINE vecHalfFloat castPacket(vecFullFloat &a) {
|
|
||||||
return _mm512_castps512_ps256(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
// castPacket identity specialization for vecFullDouble: returns the argument unchanged.
// NOTE(review): appears only on the old side of this diff hunk - presumably removed; confirm.
template<>
|
|
||||||
EIGEN_ALWAYS_INLINE vecFullDouble castPacket(vecFullDouble &a) {
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* Unrolls for transposed C stores
|
* Unrolls for transposed C stores
|
||||||
*/
|
*/
|
||||||
@ -118,7 +103,7 @@ public:
|
|||||||
pstoreu<Scalar>(
|
pstoreu<Scalar>(
|
||||||
C_arr + LDC*startN,
|
C_arr + LDC*startN,
|
||||||
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
||||||
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
|
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]),
|
||||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
||||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
||||||
}
|
}
|
||||||
@ -126,27 +111,30 @@ public:
|
|||||||
pstoreu<Scalar>(
|
pstoreu<Scalar>(
|
||||||
C_arr + LDC*startN,
|
C_arr + LDC*startN,
|
||||||
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
|
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
|
||||||
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
|
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN])));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else { // This block is only needed for fp32 case
|
||||||
|
// Reinterpret as __m512 for _mm512_shuffle_f32x4
|
||||||
|
vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
|
||||||
|
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)]);
|
||||||
|
// Swap lower and upper half of avx register.
|
||||||
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] =
|
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] =
|
||||||
_mm512_shuffle_f32x4(
|
preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
|
||||||
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)],
|
|
||||||
zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)], 0b01001110);
|
|
||||||
EIGEN_IF_CONSTEXPR(remM) {
|
EIGEN_IF_CONSTEXPR(remM) {
|
||||||
pstoreu<Scalar>(
|
pstoreu<Scalar>(
|
||||||
C_arr + LDC*startN,
|
C_arr + LDC*startN,
|
||||||
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN,
|
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN,
|
||||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
||||||
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
|
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
|
||||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
pstoreu<Scalar>(
|
pstoreu<Scalar>(
|
||||||
C_arr + LDC*startN,
|
C_arr + LDC*startN,
|
||||||
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
|
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN),
|
||||||
castPacket<vec,vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
|
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
|
aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
|
||||||
@ -234,7 +222,7 @@ public:
|
|||||||
* used as temps for transposing.
|
* used as temps for transposing.
|
||||||
*
|
*
|
||||||
* Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
|
* Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
|
||||||
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we cast packets to packets half the size (zmm -> ymm).
|
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
|
||||||
*/
|
*/
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
class transB {
|
class transB {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user