Add AVX512 s/dgemm optimizations for compute kernel (2nd try)

Author: aaraujom, 2022-05-28 02:00:21 +00:00
Committed by: Antonio Sánchez
Parent: 510f6b9f15
Commit: d49ede4dc4
9 changed files with 1377 additions and 24 deletions

File: Eigen/Core

@@ -356,6 +356,10 @@ using std::ptrdiff_t;
 #include "src/Core/arch/NEON/GeneralBlockPanelKernel.h"
 #endif
+#if defined(EIGEN_VECTORIZE_AVX512)
+#include "src/Core/arch/AVX512/GemmKernel.h"
+#endif
 #include "src/Core/BooleanRedux.h"
 #include "src/Core/Select.h"
 #include "src/Core/VectorwiseOp.h"

File: Eigen/src/Core/arch/AVX512/GemmKernel.h (new file; diff suppressed because it is too large)

File: Eigen/src/Core/arch/AVX512/PacketMath.h

@@ -927,6 +927,35 @@ EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from, uint8_t umask) {
   EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from);
 }
 
+template <typename Scalar, typename Packet>
+EIGEN_DEVICE_FUNC inline Packet pgather(const Packet& src, const Scalar* from,
+                                        Index stride, typename unpacket_traits<Packet>::mask_t umask);
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const Packet16f& src,
+                                                             const float* from,
+                                                             Index stride,
+                                                             uint16_t umask) {
+  Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
+  Packet16i stride_multiplier =
+      _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
+  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
+  __mmask16 mask = static_cast<__mmask16>(umask);
+
+  return _mm512_mask_i32gather_ps(src, mask, indices, from, 4);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const Packet8d& src,
+                                                            const double* from,
+                                                            Index stride,
+                                                            uint8_t umask) {
+  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
+  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
+  __mmask8 mask = static_cast<__mmask8>(umask);
+
+  return _mm512_mask_i32gather_pd(src, mask, indices, from, 8);
+}
+
 template <>
 EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
                                                              Index stride) {
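Note: the masked overload mirrors the masked pstoreu above: lanes whose mask bit is clear are never read from memory and instead keep their value from `src`. A minimal usage sketch, assuming an AVX512 build; the function and variable names are hypothetical, not part of the patch:

    #include <Eigen/Core>
    using namespace Eigen::internal;

    // Gather the first m (m <= 16) elements of a strided float column;
    // masked-off lanes fall back to the value supplied in `src`.
    Packet16f gather_column_tail(const float* from, Eigen::Index stride, int m) {
      uint16_t umask = static_cast<uint16_t>((1u << m) - 1);  // low m bits set
      Packet16f src = pset1<Packet16f>(0.f);                  // fill for inactive lanes
      return pgather<float, Packet16f>(src, from, stride, umask);
    }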
@@ -956,6 +985,33 @@ EIGEN_DEVICE_FUNC inline Packet16i pgather<int, Packet16i>(const int* from, Index stride) {
   return _mm512_i32gather_epi32(indices, from, 4);
 }
 
+template <typename Scalar, typename Packet>
+EIGEN_DEVICE_FUNC inline void pscatter(Scalar* to, const Packet& from,
+                                       Index stride, typename unpacket_traits<Packet>::mask_t umask);
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to,
+                                                         const Packet16f& from,
+                                                         Index stride,
+                                                         uint16_t umask) {
+  Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
+  Packet16i stride_multiplier =
+      _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
+  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
+  __mmask16 mask = static_cast<__mmask16>(umask);
+
+  _mm512_mask_i32scatter_ps(to, mask, indices, from, 4);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to,
+                                                         const Packet8d& from,
+                                                         Index stride,
+                                                         uint8_t umask) {
+  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
+  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
+  __mmask8 mask = static_cast<__mmask8>(umask);
+
+  _mm512_mask_i32scatter_pd(to, mask, indices, from, 8);
+}
+
 template <>
 EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to,
                                                          const Packet16f& from,
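The scatter side is symmetric: only lanes whose mask bit is set are written, so a remainder column can be stored without touching memory past its end. A minimal sketch under the same assumptions (hypothetical names):

    // Store the first m (m <= 16) lanes of `v` to a strided float column;
    // memory behind masked-off lanes is left untouched.
    void scatter_column_tail(float* to, const Packet16f& v, Eigen::Index stride, int m) {
      uint16_t umask = static_cast<uint16_t>((1u << m) - 1);
      pscatter<float, Packet16f>(to, v, stride, umask);
    }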
@@ -1451,27 +1507,23 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
   kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
   kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
 
-  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
-  T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
-  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
-  T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
-  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
-  T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
-  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
-  T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
-  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
-  T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
-  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
-  T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
-  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
-  T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
-  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
-  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
-
-  kernel.packet[0] = T0; kernel.packet[1] = T1;
-  kernel.packet[2] = T2; kernel.packet[3] = T3;
-  kernel.packet[4] = T4; kernel.packet[5] = T5;
-  kernel.packet[6] = T6; kernel.packet[7] = T7;
+  T0 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0x44);
+  T1 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0xee);
+  T2 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0x44);
+  T3 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0xee);
+  T4 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0x44);
+  T5 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0xee);
+  T6 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0x44);
+  T7 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0xee);
+
+  kernel.packet[0] = _mm512_shuffle_f32x4(T0, T2, 0x88);
+  kernel.packet[2] = _mm512_shuffle_f32x4(T0, T2, 0xdd);
+  kernel.packet[1] = _mm512_shuffle_f32x4(T4, T6, 0x88);
+  kernel.packet[3] = _mm512_shuffle_f32x4(T4, T6, 0xdd);
+  kernel.packet[4] = _mm512_shuffle_f32x4(T1, T3, 0x88);
+  kernel.packet[6] = _mm512_shuffle_f32x4(T1, T3, 0xdd);
+  kernel.packet[5] = _mm512_shuffle_f32x4(T5, T7, 0x88);
+  kernel.packet[7] = _mm512_shuffle_f32x4(T5, T7, 0xdd);
 }
 
 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
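The rewritten tail of this transpose is built entirely from _mm512_shuffle_f32x4, which moves whole 128-bit lanes between two sources: the immediate packs four 2-bit lane selectors, the low two fields picking lanes from the first operand and the high two from the second. As a sketch of what the selectors used above do:

    #include <immintrin.h>

    // With a = [a0 a1 a2 a3] and b = [b0 b1 b2 b3] in 128-bit lanes:
    //   imm 0x44 = 01'00'01'00  ->  [a0 a1 b0 b1]   (low halves of a and b)
    //   imm 0xee = 11'10'11'10  ->  [a2 a3 b2 b3]   (high halves)
    //   imm 0x88 = 10'00'10'00  ->  [a0 a2 b0 b2]   (even lanes)
    //   imm 0xdd = 11'01'11'01  ->  [a1 a3 b1 b3]   (odd lanes)
    __m512 interleave_low_128(__m512 a, __m512 b) {
      return _mm512_shuffle_f32x4(a, b, 0x44);
    }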

File: Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc

@@ -65,6 +65,57 @@ EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
   return 0;
 }
 
+template <typename Packet>
+EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8>& kernel);
+
+template <>
+EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8>& kernel) {
+  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]);
+  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]);
+  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2],kernel.packet[3]);
+  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2],kernel.packet[3]);
+  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4],kernel.packet[5]);
+  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4],kernel.packet[5]);
+  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]);
+  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]);
+
+  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
+  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
+  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
+  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
+  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
+  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
+  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
+  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
+
+  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
+  T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
+  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
+  T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
+  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
+  T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
+  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
+  T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
+  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
+  T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
+  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
+  T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
+  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
+  T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
+  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
+  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
+
+  kernel.packet[0] = T0; kernel.packet[1] = T1;
+  kernel.packet[2] = T2; kernel.packet[3] = T3;
+  kernel.packet[4] = T4; kernel.packet[5] = T5;
+  kernel.packet[6] = T6; kernel.packet[7] = T7;
+}
+
+template <>
+EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8>& kernel) {
+  ptranspose(kernel);
+}
+
 /***
  * Unrolls for tranposed C stores
  */
@@ -198,7 +249,7 @@ public:
       r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5];
       r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6];
       r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7];
-      ptranspose(r);
+      trans8x8blocks(r);
       zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0];
      zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1];
      zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2];
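Note: trans8x8blocks<Packet16f> is, in effect, the previous body of ptranspose(PacketBlock<Packet16f, 8>&): every step stays within 256-bit halves, so it transposes the two 8x8 float tiles held in the low and high halves of the eight registers independently, which is the layout the transposed-C store unrolls above expect; for Packet8d an 8x8 transpose is simply ptranspose. A minimal usage sketch, assuming an AVX512 build (hypothetical names):

    #include <Eigen/Core>
    using namespace Eigen::internal;

    // Transpose the two 8x8 tiles of an 8x16 float block held in registers.
    void transpose_tiles(const float* rows /* 8 contiguous rows of 16 floats */) {
      PacketBlock<Packet16f, 8> k;
      for (int r = 0; r < 8; ++r) k.packet[r] = ploadu<Packet16f>(rows + 16 * r);
      trans8x8blocks(k);
      // After the call, packet[r]'s low 256 bits correspond to column r of the
      // left 8x8 tile, and its high 256 bits to column r of the right tile.
    }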

File: Eigen/src/Core/arch/AVX512/TypeCasting.h

@@ -44,6 +44,34 @@ template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
   return _mm512_castps512_ps256(a);
 }
 
+template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet16f>(const Packet16f& a) {
+  return _mm512_castps512_ps128(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet8d>(const Packet8d& a) {
+  return _mm512_castpd512_pd256(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet8d>(const Packet8d& a) {
+  return _mm512_castpd512_pd128(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8f>(const Packet8f& a) {
+  return _mm512_castps256_ps512(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet4f>(const Packet4f& a) {
+  return _mm512_castps128_ps512(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet4d>(const Packet4d& a) {
+  return _mm512_castpd256_pd512(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const Packet2d& a) {
+  return _mm512_castpd128_pd512(a);
+}
+
 template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
   return a;
 }
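All of these casts compile to nothing: the _mm512_cast* intrinsics are reinterprets, and the widening casts leave the upper bits undefined. That lets a kernel drop to a narrower packet for tail work for free. A small sketch (hypothetical name):

    #include <Eigen/Core>
    using namespace Eigen::internal;

    // View the low 128 bits of a 512-bit accumulator as a Packet4f; no moves.
    Packet4f low_quarter(const Packet16f& acc) {
      return preinterpret<Packet4f, Packet16f>(acc);
    }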

File: Eigen/src/Core/arch/SSE/PacketMath.h

@@ -285,6 +285,10 @@ template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_add_epi32(a,b); }
 template<> EIGEN_STRONG_INLINE Packet16b padd<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); }
 
+template<typename Packet> EIGEN_STRONG_INLINE Packet padds(const Packet& a, const Packet& b);
+template<> EIGEN_STRONG_INLINE Packet4f padds<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_add_ss(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d padds<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_add_sd(a,b); }
+
 template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_sub_ps(a,b); }
 template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_sub_pd(a,b); }
 template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_sub_epi32(a,b); }

@@ -370,6 +374,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmadd_ps(a,b,c); }
 template<> EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmadd_pd(a,b,c); }
 template<> EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmsub_ps(a,b,c); }
 template<> EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmsub_pd(a,b,c); }
+
+template<typename Packet> EIGEN_STRONG_INLINE Packet pmadds(const Packet& a, const Packet& b, const Packet& c);
+template<> EIGEN_STRONG_INLINE Packet4f pmadds<Packet4f>(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ss(a,b,c); }
+template<> EIGEN_STRONG_INLINE Packet2d pmadds<Packet2d>(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_sd(a,b,c); }
 #endif
 
 #ifdef EIGEN_VECTORIZE_SSE4_1

@@ -746,6 +754,15 @@ template<> EIGEN_STRONG_INLINE Packet16b ploadu<Packet16b>(const bool* from)
   return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
 }
 
+// Load lower part of packet zero extending.
+template<typename Packet> EIGEN_STRONG_INLINE Packet ploadl(const typename unpacket_traits<Packet>::type* from);
+template<> EIGEN_STRONG_INLINE Packet4f ploadl<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(from))); }
+template<> EIGEN_STRONG_INLINE Packet2d ploadl<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); }
+
+// Load scalar
+template<typename Packet> EIGEN_STRONG_INLINE Packet ploads(const typename unpacket_traits<Packet>::type* from);
+template<> EIGEN_STRONG_INLINE Packet4f ploads<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_ss(from); }
+template<> EIGEN_STRONG_INLINE Packet2d ploads<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); }
+
 template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
 {

@@ -787,6 +804,14 @@ template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ps(to, from); }
 template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
 template<> EIGEN_STRONG_INLINE void pstoreu<bool>(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
 
+template<typename Scalar, typename Packet> EIGEN_STRONG_INLINE void pstorel(Scalar* to, const Packet& from);
+template<> EIGEN_STRONG_INLINE void pstorel(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pi(reinterpret_cast<__m64*>(to), from); }
+template<> EIGEN_STRONG_INLINE void pstorel(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pd(to, from); }
+
+template<typename Scalar, typename Packet> EIGEN_STRONG_INLINE void pstores(Scalar* to, const Packet& from);
+template<> EIGEN_STRONG_INLINE void pstores(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_ss(to, from); }
+template<> EIGEN_STRONG_INLINE void pstores(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_sd(to, from); }
+
 template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
 {
   return _mm_set_ps(from[3*stride], from[2*stride], from[1*stride], from[0*stride]);
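The `s`-suffixed helpers operate on the lowest lane only (_mm_add_ss, _mm_fmadd_ss, ...), and the `l`-suffixed ones on the low 64 bits, which lets remainder handling stay in the packet vocabulary instead of dropping to scalar code. A sketch of a one-element FMA tail, assuming FMA is enabled since pmadds lives in that block (hypothetical names):

    #include <Eigen/Core>
    using namespace Eigen::internal;

    // c[0] += a[0] * b[0]; upper lanes are never stored.
    void fma_tail(float* c, const float* a, const float* b) {
      Packet4f pa = ploads<Packet4f>(a);   // _mm_load_ss: loads a[0], zeroes the rest
      Packet4f pb = ploads<Packet4f>(b);
      Packet4f pc = ploads<Packet4f>(c);
      pstores(c, pmadds<Packet4f>(pa, pb, pc));  // _mm_fmadd_ss + _mm_store_ss
    }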

File: Eigen/src/Core/arch/SSE/TypeCasting.h

@@ -71,6 +71,14 @@ template<> EIGEN_STRONG_INLINE Packet2d pcast<Packet4f, Packet2d>(const Packet4f& a) {
   return _mm_cvtps_pd(a);
 }
 
+template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet4f>(const Packet4f& a) {
+  return _mm_castps_pd(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet2d>(const Packet2d& a) {
+  return _mm_castpd_ps(a);
+}
+
 template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
   return _mm_castps_si128(a);
 }
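These are plain bitcasts (_mm_castps_pd / _mm_castpd_ps generate no instructions); ploadl<Packet4f> above relies on exactly this kind of reinterpret to express a 64-bit load of two floats via _mm_load_sd. A trivial sketch:

    // Round-trips the bit pattern; compiles to nothing.
    Packet4f roundtrip(const Packet4f& a) {
      return preinterpret<Packet4f, Packet2d>(preinterpret<Packet2d, Packet4f>(a));
    }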

File: Eigen/src/Core/products/GeneralMatrixMatrix.h

@@ -287,7 +287,6 @@ class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, MaxDepth, KcFactor, true>
   };
   typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
   typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
-  typedef gebp_traits<LhsScalar,RhsScalar> Traits;
   enum {
     SizeA = ActualRows * MaxDepth,
     SizeB = ActualCols * MaxDepth

@@ -336,7 +335,6 @@ class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, MaxDepth, KcFactor, false>
   };
   typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
   typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
-  typedef gebp_traits<LhsScalar,RhsScalar> Traits;
 
   Index m_sizeA;
   Index m_sizeB;

File: Eigen/src/Core/util/BlasUtil.h

@@ -229,6 +229,7 @@ public:
   }
 
   EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
+  EIGEN_DEVICE_FUNC const Index incr() const { return 1; }
   EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
 
   EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {

@@ -402,6 +403,10 @@ public:
     storePacketBlock_helper<SubPacket, Scalar, n, n-1> spb;
     spb.store(this, i,j,block);
   }
+
+  EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
+  EIGEN_DEVICE_FUNC const Index incr() const { return m_incr.value(); }
+  EIGEN_DEVICE_FUNC Scalar* data() const { return m_data; }
+
 protected:
   Scalar* EIGEN_RESTRICT m_data;
   const Index m_stride;
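With incr() now exposed on both mappers (a compile-time 1 for the packed one, m_incr.value() for the strided one), kernel code can query the inner-element increment generically. A minimal sketch of such generic use (hypothetical function):

    // Works with either blas_data_mapper flavour.
    template <typename DataMapper>
    bool is_packed(const DataMapper& m) {
      return m.incr() == 1;  // consecutive elements are contiguous in memory
    }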