mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-30 07:44:10 +08:00
Add AVX512 s/dgemm optimizations for compute kernel (2nd try)
This commit is contained in:
parent
510f6b9f15
commit
d49ede4dc4
@ -356,6 +356,10 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/arch/NEON/GeneralBlockPanelKernel.h"
|
||||
#endif
|
||||
|
||||
#if defined(EIGEN_VECTORIZE_AVX512)
|
||||
#include "src/Core/arch/AVX512/GemmKernel.h"
|
||||
#endif
|
||||
|
||||
#include "src/Core/BooleanRedux.h"
|
||||
#include "src/Core/Select.h"
|
||||
#include "src/Core/VectorwiseOp.h"
|
||||
|
1182
Eigen/src/Core/arch/AVX512/GemmKernel.h
Normal file
1182
Eigen/src/Core/arch/AVX512/GemmKernel.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -927,6 +927,35 @@ EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from, uint8
|
||||
EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from);
|
||||
}
|
||||
|
||||
// Masked strided gather: generic declaration only; the per-packet
// specializations provide the definitions.
//   src    - packet supplying the value of every lane whose mask bit is clear
//   from   - base address of the strided source data
//   stride - distance, in elements, between two consecutive lanes
//   umask  - one bit per lane; the integer type comes from
//            unpacket_traits<Packet>::mask_t
template <typename Scalar, typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet pgather(const Packet& src, const Scalar* from,
|
||||
Index stride, typename unpacket_traits<Packet>::mask_t umask);
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const Packet16f& src,
|
||||
const float* from,
|
||||
Index stride,
|
||||
uint16_t umask) {
|
||||
Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
|
||||
Packet16i stride_multiplier =
|
||||
_mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
||||
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
|
||||
__mmask16 mask = static_cast<__mmask16>(umask);
|
||||
|
||||
return _mm512_mask_i32gather_ps(src, mask, indices, from, 4);
|
||||
}
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const Packet8d& src,
|
||||
const double* from,
|
||||
Index stride,
|
||||
uint8_t umask) {
|
||||
Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
|
||||
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
|
||||
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
|
||||
__mmask8 mask = static_cast<__mmask8>(umask);
|
||||
|
||||
return _mm512_mask_i32gather_pd(src, mask, indices, from, 8);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
|
||||
Index stride) {
|
||||
@ -956,6 +985,33 @@ EIGEN_DEVICE_FUNC inline Packet16i pgather<int, Packet16i>(const int* from,
|
||||
return _mm512_i32gather_epi32(indices, from, 4);
|
||||
}
|
||||
|
||||
// Masked strided scatter: generic declaration only; the per-packet
// specializations provide the definitions.
//   to     - base address of the strided destination
//   from   - packet holding the values to write
//   stride - distance, in elements, between two consecutive lanes
//   umask  - one bit per lane; the integer type comes from
//            unpacket_traits<Packet>::mask_t
template <typename Scalar, typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline void pscatter(Scalar* to, const Packet& from,
|
||||
Index stride, typename unpacket_traits<Packet>::mask_t umask);
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to,
|
||||
const Packet16f& from,
|
||||
Index stride,
|
||||
uint16_t umask) {
|
||||
Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
|
||||
Packet16i stride_multiplier =
|
||||
_mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
||||
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
|
||||
__mmask16 mask = static_cast<__mmask16>(umask);
|
||||
_mm512_mask_i32scatter_ps(to, mask, indices, from, 4);
|
||||
}
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to,
|
||||
const Packet8d& from,
|
||||
Index stride,
|
||||
uint8_t umask) {
|
||||
Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
|
||||
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
|
||||
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
|
||||
__mmask8 mask = static_cast<__mmask8>(umask);
|
||||
_mm512_mask_i32scatter_pd(to, mask, indices, from, 8);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to,
|
||||
const Packet16f& from,
|
||||
@ -1450,28 +1506,24 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
|
||||
kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
|
||||
kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
|
||||
kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
|
||||
|
||||
T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
|
||||
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
|
||||
T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
|
||||
T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
|
||||
T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
|
||||
T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
|
||||
T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
|
||||
T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
|
||||
T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
|
||||
T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
|
||||
T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
|
||||
T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
|
||||
T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
|
||||
T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
|
||||
T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
|
||||
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
|
||||
|
||||
kernel.packet[0] = T0; kernel.packet[1] = T1;
|
||||
kernel.packet[2] = T2; kernel.packet[3] = T3;
|
||||
kernel.packet[4] = T4; kernel.packet[5] = T5;
|
||||
kernel.packet[6] = T6; kernel.packet[7] = T7;
|
||||
T0 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0x44);
|
||||
T1 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0xee);
|
||||
T2 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0x44);
|
||||
T3 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0xee);
|
||||
T4 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0x44);
|
||||
T5 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0xee);
|
||||
T6 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0x44);
|
||||
T7 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0xee);
|
||||
|
||||
kernel.packet[0] = _mm512_shuffle_f32x4(T0, T2, 0x88);
|
||||
kernel.packet[2] = _mm512_shuffle_f32x4(T0, T2, 0xdd);
|
||||
kernel.packet[1] = _mm512_shuffle_f32x4(T4, T6, 0x88);
|
||||
kernel.packet[3] = _mm512_shuffle_f32x4(T4, T6, 0xdd);
|
||||
kernel.packet[4] = _mm512_shuffle_f32x4(T1, T3, 0x88);
|
||||
kernel.packet[6] = _mm512_shuffle_f32x4(T1, T3, 0xdd);
|
||||
kernel.packet[5] = _mm512_shuffle_f32x4(T5, T7, 0x88);
|
||||
kernel.packet[7] = _mm512_shuffle_f32x4(T5, T7, 0xdd);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
|
||||
|
@ -65,6 +65,57 @@ EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// In-place transpose helper operating on a block of eight packets.
// Declaration only: the behaviour is defined by the per-packet
// specializations (Packet16f and Packet8d).
template <typename Packet>
|
||||
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8>& kernel);
|
||||
|
||||
template <>
|
||||
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8>& kernel) {
|
||||
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]);
|
||||
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]);
|
||||
__m512 T2 = _mm512_unpacklo_ps(kernel.packet[2],kernel.packet[3]);
|
||||
__m512 T3 = _mm512_unpackhi_ps(kernel.packet[2],kernel.packet[3]);
|
||||
__m512 T4 = _mm512_unpacklo_ps(kernel.packet[4],kernel.packet[5]);
|
||||
__m512 T5 = _mm512_unpackhi_ps(kernel.packet[4],kernel.packet[5]);
|
||||
__m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]);
|
||||
__m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]);
|
||||
|
||||
kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
|
||||
kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2)));
|
||||
kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
|
||||
kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3)));
|
||||
kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
|
||||
kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
|
||||
kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
|
||||
kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
|
||||
|
||||
T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
|
||||
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
|
||||
T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
|
||||
T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
|
||||
T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
|
||||
T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
|
||||
T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
|
||||
T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
|
||||
T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
|
||||
T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
|
||||
T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
|
||||
T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
|
||||
T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
|
||||
T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
|
||||
T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
|
||||
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
|
||||
|
||||
kernel.packet[0] = T0; kernel.packet[1] = T1;
|
||||
kernel.packet[2] = T2; kernel.packet[3] = T3;
|
||||
kernel.packet[4] = T4; kernel.packet[5] = T5;
|
||||
kernel.packet[6] = T6; kernel.packet[7] = T7;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8>& kernel) {
|
||||
ptranspose(kernel);
|
||||
}
|
||||
|
||||
/***
|
||||
* Unrolls for transposed C stores
|
||||
*/
|
||||
@ -198,7 +249,7 @@ public:
|
||||
r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5];
|
||||
r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6];
|
||||
r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7];
|
||||
ptranspose(r);
|
||||
trans8x8blocks(r);
|
||||
zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0];
|
||||
zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1];
|
||||
zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2];
|
||||
|
@ -44,6 +44,34 @@ template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const
|
||||
return _mm512_castps512_ps256(a);
|
||||
}
|
||||
|
||||
// Narrowing reinterpretations of 512-bit packets: each returns the lowest
// 256 or 128 bits of the source register through a cast intrinsic, without
// converting any value.

// Packet16f -> Packet4f: lowest 128 bits (floats 0..3).
template<>
EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet16f>(const Packet16f& a) {
  return _mm512_castps512_ps128(a);
}

// Packet8d -> Packet4d: lowest 256 bits (doubles 0..3).
template<>
EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet8d>(const Packet8d& a) {
  return _mm512_castpd512_pd256(a);
}

// Packet8d -> Packet2d: lowest 128 bits (doubles 0..1).
template<>
EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet8d>(const Packet8d& a) {
  return _mm512_castpd512_pd128(a);
}
|
||||
|
||||
// Widening reinterpretations into 512-bit packets. The source occupies the
// low bits of the result; per the Intel intrinsics documentation the cast
// intrinsics leave the remaining upper bits undefined, so callers must not
// rely on their contents.

// Packet8f -> Packet16f: low 256 bits defined.
template<>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8f>(const Packet8f& a) {
  return _mm512_castps256_ps512(a);
}

// Packet4f -> Packet16f: low 128 bits defined.
template<>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet4f>(const Packet4f& a) {
  return _mm512_castps128_ps512(a);
}

// Packet4d -> Packet8d: low 256 bits defined.
template<>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet4d>(const Packet4d& a) {
  return _mm512_castpd256_pd512(a);
}

// Packet2d -> Packet8d: low 128 bits defined.
template<>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const Packet2d& a) {
  return _mm512_castpd128_pd512(a);
}
|
||||
|
||||
// Identity reinterpretation: a Packet16f viewed as a Packet16f is returned unchanged.
template<>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
  return a;
}
|
||||
|
@ -285,6 +285,10 @@ template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const
|
||||
|
||||
// "Addition" of boolean packets is implemented as a bitwise OR of the two
// 128-bit registers.
template<>
EIGEN_STRONG_INLINE Packet16b padd<Packet16b>(const Packet16b& a, const Packet16b& b) {
  return _mm_or_si128(a, b);
}
|
||||
|
||||
// Scalar add: only the lowest lane is summed. Generic declaration; the SSE
// specializations below provide the definitions.
template<typename Packet>
EIGEN_STRONG_INLINE Packet padds(const Packet& a, const Packet& b);

// Packet4f: lane 0 = a[0] + b[0]; lanes 1..3 are taken from a (_mm_add_ss).
template<>
EIGEN_STRONG_INLINE Packet4f padds<Packet4f>(const Packet4f& a, const Packet4f& b) {
  return _mm_add_ss(a, b);
}

// Packet2d: lane 0 = a[0] + b[0]; lane 1 is taken from a (_mm_add_sd).
template<>
EIGEN_STRONG_INLINE Packet2d padds<Packet2d>(const Packet2d& a, const Packet2d& b) {
  return _mm_add_sd(a, b);
}
|
||||
|
||||
// Lane-wise subtraction a - b for the SSE packet types.

// Four floats.
template<>
EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) {
  return _mm_sub_ps(a, b);
}

// Two doubles.
template<>
EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) {
  return _mm_sub_pd(a, b);
}

// Four 32-bit integers.
template<>
EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) {
  return _mm_sub_epi32(a, b);
}
|
||||
@ -370,6 +374,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f
|
||||
// Fused negated multiply-add/subtract, mapped directly onto the FMA intrinsics.

// Packet2d: -(a * b) + c per lane.
template<>
EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
  return _mm_fnmadd_pd(a, b, c);
}

// Packet4f: -(a * b) - c per lane.
template<>
EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
  return _mm_fnmsub_ps(a, b, c);
}

// Packet2d: -(a * b) - c per lane.
template<>
EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
  return _mm_fnmsub_pd(a, b, c);
}
|
||||
|
||||
// Scalar fused multiply-add: only the lowest lane is computed. Generic
// declaration; the SSE specializations below provide the definitions.
template<typename Packet>
EIGEN_STRONG_INLINE Packet pmadds(const Packet& a, const Packet& b, const Packet& c);

// Packet4f: lane 0 = a[0] * b[0] + c[0]; lanes 1..3 are taken from a (_mm_fmadd_ss).
template<>
EIGEN_STRONG_INLINE Packet4f pmadds<Packet4f>(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
  return _mm_fmadd_ss(a, b, c);
}

// Packet2d: lane 0 = a[0] * b[0] + c[0]; lane 1 is taken from a (_mm_fmadd_sd).
template<>
EIGEN_STRONG_INLINE Packet2d pmadds<Packet2d>(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
  return _mm_fmadd_sd(a, b, c);
}
|
||||
#endif
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_SSE4_1
|
||||
@ -746,6 +754,15 @@ template<> EIGEN_STRONG_INLINE Packet16b ploadu<Packet16b>(const bool* from)
|
||||
return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
|
||||
}
|
||||
|
||||
// Load lower part of packet zero extending.
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet ploadl(const typename unpacket_traits<Packet>::type* from);
|
||||
template<> EIGEN_STRONG_INLINE Packet4f ploadl<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(from))); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d ploadl<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); }
|
||||
|
||||
// Load scalar
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet ploads(const typename unpacket_traits<Packet>::type* from);
|
||||
template<> EIGEN_STRONG_INLINE Packet4f ploads<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_ss(from); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d ploads<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
|
||||
{
|
||||
@ -787,6 +804,14 @@ template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f&
|
||||
// Unaligned 128-bit stores. Both specializations issue _mm_storeu_si128, so
// both are tagged with the unaligned-store debug hook.

// Packet4i: stores four 32-bit ints to `to` (no alignment requirement).
template<>
EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from) {
  EIGEN_DEBUG_UNALIGNED_STORE
  _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from);
}

// Packet16b: stores sixteen bools to `to` (no alignment requirement).
// This was previously tagged EIGEN_DEBUG_ALIGNED_STORE even though the store
// itself is unaligned, so it now uses the same hook as the sibling stores.
template<>
EIGEN_STRONG_INLINE void pstoreu<bool>(bool* to, const Packet16b& from) {
  EIGEN_DEBUG_UNALIGNED_STORE
  _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from);
}
|
||||
|
||||
// Store the lower 64 bits of a packet to memory. Generic declaration; the
// SSE specializations below provide the definitions.
template<typename Scalar, typename Packet>
EIGEN_STRONG_INLINE void pstorel(Scalar* to, const Packet& from);

// Packet4f: writes lanes 0..1 (two floats) to `to`.
template<>
EIGEN_STRONG_INLINE void pstorel(float* to, const Packet4f& from) {
  EIGEN_DEBUG_UNALIGNED_STORE
  _mm_storel_pi(reinterpret_cast<__m64*>(to), from);
}

// Packet2d: writes lane 0 (one double) to `to`.
template<>
EIGEN_STRONG_INLINE void pstorel(double* to, const Packet2d& from) {
  EIGEN_DEBUG_UNALIGNED_STORE
  _mm_storel_pd(to, from);
}
|
||||
|
||||
// Store only the lowest lane of a packet to memory. Generic declaration; the
// SSE specializations below provide the definitions.
template<typename Scalar, typename Packet>
EIGEN_STRONG_INLINE void pstores(Scalar* to, const Packet& from);

// Packet4f: writes lane 0 (one float) to `to`.
template<>
EIGEN_STRONG_INLINE void pstores(float* to, const Packet4f& from) {
  EIGEN_DEBUG_UNALIGNED_STORE
  _mm_store_ss(to, from);
}

// Packet2d: writes lane 0 (one double) to `to`.
template<>
EIGEN_STRONG_INLINE void pstores(double* to, const Packet2d& from) {
  EIGEN_DEBUG_UNALIGNED_STORE
  _mm_store_sd(to, from);
}
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
|
||||
{
|
||||
return _mm_set_ps(from[3*stride], from[2*stride], from[1*stride], from[0*stride]);
|
||||
|
@ -71,6 +71,14 @@ template<> EIGEN_STRONG_INLINE Packet2d pcast<Packet4f, Packet2d>(const Packet4f
|
||||
return _mm_cvtps_pd(a);
|
||||
}
|
||||
|
||||
// Bit-level reinterpretations between 128-bit SSE packet types. Each uses a
// cast intrinsic, so the 128 bits are passed through without value conversion.

// Four floats viewed as two doubles.
template<>
EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet4f>(const Packet4f& a) {
  return _mm_castps_pd(a);
}

// Two doubles viewed as four floats.
template<>
EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet2d>(const Packet2d& a) {
  return _mm_castpd_ps(a);
}

// Four floats viewed as four 32-bit integers.
template<>
EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
  return _mm_castps_si128(a);
}
|
||||
|
@ -287,7 +287,6 @@ class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, M
|
||||
};
|
||||
typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
|
||||
typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||
enum {
|
||||
SizeA = ActualRows * MaxDepth,
|
||||
SizeB = ActualCols * MaxDepth
|
||||
@ -336,7 +335,6 @@ class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, M
|
||||
};
|
||||
typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
|
||||
typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||
|
||||
Index m_sizeA;
|
||||
Index m_sizeB;
|
||||
|
@ -229,6 +229,7 @@ public:
|
||||
}
|
||||
|
||||
// Accessors for the mapper's layout parameters.
// stride() and incr() return by value without a top-level const: a
// const-qualified by-value return of a scalar is ignored by the language and
// only produces -Wignored-qualifiers warnings. (Index is presumably Eigen's
// integer index typedef.)

// Returns the stored stride (m_stride).
EIGEN_DEVICE_FUNC Index stride() const { return m_stride; }
// Returns the element increment, which is always 1 for this mapper.
EIGEN_DEVICE_FUNC Index incr() const { return 1; }
// Returns a read-only pointer to the mapped data (m_data).
EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
|
||||
|
||||
EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
|
||||
@ -402,6 +403,10 @@ public:
|
||||
storePacketBlock_helper<SubPacket, Scalar, n, n-1> spb;
|
||||
spb.store(this, i,j,block);
|
||||
}
|
||||
|
||||
// Accessors for the mapper's layout parameters.
// stride() and incr() return by value without a top-level const: a
// const-qualified by-value return of a scalar is ignored by the language and
// only produces -Wignored-qualifiers warnings. (Index is presumably Eigen's
// integer index typedef.)

// Returns the stored stride (m_stride).
EIGEN_DEVICE_FUNC Index stride() const { return m_stride; }
// Returns the element increment held in m_incr.
EIGEN_DEVICE_FUNC Index incr() const { return m_incr.value(); }
// Returns the pointer to the mapped data (m_data).
EIGEN_DEVICE_FUNC Scalar* data() const { return m_data; }
|
||||
protected:
|
||||
Scalar* EIGEN_RESTRICT m_data;
|
||||
const Index m_stride;
|
||||
|
Loading…
x
Reference in New Issue
Block a user