Mirror of https://gitlab.com/libeigen/eigen.git
AVX512 Optimizations for Triangular Solve
This commit is contained in:
parent 01b5bc48cc
commit 518fc321cb
Eigen/Core
@@ -190,6 +190,7 @@ using std::ptrdiff_t;
#include "src/Core/arch/SSE/MathFunctions.h"
#include "src/Core/arch/AVX/MathFunctions.h"
#include "src/Core/arch/AVX512/MathFunctions.h"
#include "src/Core/arch/AVX512/trsmKernel_impl.hpp"
#elif defined EIGEN_VECTORIZE_AVX
// Use AVX for floats and doubles, SSE for integers
#include "src/Core/arch/SSE/PacketMath.h"
Eigen/src/Core/GenericPacketMath.h
@@ -220,6 +220,15 @@ padd(const Packet& a, const Packet& b) { return a+b; }
template<> EIGEN_DEVICE_FUNC inline bool
padd(const bool& a, const bool& b) { return a || b; }

/** \internal \returns a packet version of \a *from, (un-aligned masked add)
 * There is no generic implementation. We only have implementations for specialized
 * cases. Generic case should not be called.
 */
template<typename Packet> EIGEN_DEVICE_FUNC inline
std::enable_if_t<unpacket_traits<Packet>::masked_fpops_available, Packet>
padd(const Packet& a, const Packet& b, typename unpacket_traits<Packet>::mask_t umask);


/** \internal \returns a - b (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
psub(const Packet& a, const Packet& b) { return a-b; }
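For reference, a minimal usage sketch of these masked overloads (not part of this commit; the helper name add_tail and the tail-handling pattern are illustrative, and assume a packet type whose unpacket_traits advertise masked_load_available, masked_store_available and masked_fpops_available, with n smaller than the packet size):

#include <Eigen/Core>
// Hedged sketch, not from the patch: add two length-n tails using the masked
// overloads declared above; lanes outside the mask are never written.
template <typename Packet, typename Scalar>
void add_tail(Scalar* dst, const Scalar* lhs, const Scalar* rhs, Eigen::Index n) {
  typedef Eigen::internal::unpacket_traits<Packet> traits;
  typename traits::mask_t umask =
      static_cast<typename traits::mask_t>((1ULL << n) - 1);   // enable the low n lanes
  Packet a = Eigen::internal::ploadu<Packet>(lhs, umask);      // masked unaligned load
  Packet b = Eigen::internal::ploadu<Packet>(rhs, umask);
  Packet s = Eigen::internal::padd<Packet>(a, b, umask);       // masked add, disabled lanes are zero
  Eigen::internal::pstoreu<Scalar>(dst, s, umask);             // only the enabled lanes are stored
}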
Eigen/src/Core/arch/AVX/PacketMath.h
@@ -236,7 +236,11 @@ template<> struct unpacket_traits<Packet8f> {
typedef Packet4f half;
typedef Packet8i integer_packet;
typedef uint8_t mask_t;
enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true};
enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true
#ifdef EIGEN_VECTORIZE_AVX512
, masked_fpops_available=true
#endif
};
};
template<> struct unpacket_traits<Packet4d> {
typedef double type;
@@ -464,6 +468,13 @@ template<> EIGEN_STRONG_INLINE Packet8f pload1<Packet8f>(const float* from) { r
template<> EIGEN_STRONG_INLINE Packet4d pload1<Packet4d>(const double* from) { return _mm256_broadcast_sd(from); }

template<> EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_add_ps(a,b); }
#ifdef EIGEN_VECTORIZE_AVX512
template <>
EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b, uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
return _mm256_maskz_add_ps(mask, a, b);
}
#endif
template<> EIGEN_STRONG_INLINE Packet4d padd<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_add_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet8i padd<Packet8i>(const Packet8i& a, const Packet8i& b) {
#ifdef EIGEN_VECTORIZE_AVX2
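The AVX512 guard is needed because _mm256_maskz_add_ps is an AVX512VL instruction form rather than plain AVX. A standalone, intrinsic-level illustration of its zero-masking semantics (not from the commit; requires AVX512F plus AVX512VL, e.g. -mavx512vl):

#include <immintrin.h>
// With umask = 0b00000101 only lanes 0 and 2 receive a + b; every other lane is zeroed.
__m256 masked_add_example(__m256 a, __m256 b) {
  __mmask8 mask = 0b00000101;
  return _mm256_maskz_add_ps(mask, a, b);
}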
@@ -848,11 +859,16 @@ template<> EIGEN_STRONG_INLINE Packet4d ploadu<Packet4d>(const double* from) { E
template<> EIGEN_STRONG_INLINE Packet8i ploadu<Packet8i>(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); }

template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from, uint8_t umask) {
#ifdef EIGEN_VECTORIZE_AVX512
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskz_loadu_ps(mask, from);
#else
Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
mask = por<Packet8i>(mask, bit_mask);
mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskload_ps(from, mask);
#endif
}

// Loads 4 floats from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, a3}
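In the pre-AVX512 fallback above, the byte mask is widened into a per-lane vector mask: each bit_mask entry is ~(1 << lane), so OR-ing it with the broadcast umask produces all-ones in lane i exactly when bit i of umask is set, and the pcmp_eq against 0xffffffff turns that into the -1/0 lane mask that _mm256_maskload_ps expects. A scalar model of that logic (illustrative only, not part of the patch):

#include <cstdint>
// Returns 0xffffffff when bit `lane` of umask is set, 0 otherwise (lane in [0,7]).
uint32_t lane_mask(uint8_t umask, int lane) {
  uint32_t bcast = 0x01010101u * umask;         // one 32-bit lane of _mm256_set1_epi8(umask)
  uint32_t bit_mask = ~(uint32_t(1) << lane);   // the matching entry of the bit_mask constant
  return ((bcast | bit_mask) == 0xffffffffu) ? 0xffffffffu : 0u;   // the pcmp_eq step
}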
@@ -911,11 +927,16 @@ template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet4d&
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); }

template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from, uint8_t umask) {
#ifdef EIGEN_VECTORIZE_AVX512
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm256_mask_storeu_ps(to, mask, from);
#else
Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
mask = por<Packet8i>(mask, bit_mask);
mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
EIGEN_DEBUG_UNALIGNED_STORE return _mm256_maskstore_ps(to, mask, from);
#endif
}

// NOTE: leverage _mm256_i32gather_ps and _mm256_i32gather_pd if AVX2 instructions are available
Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -180,13 +180,14 @@ struct unpacket_traits<Packet16f> {
typedef Packet8f half;
typedef Packet16i integer_packet;
typedef uint16_t mask_t;
enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true };
enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true, masked_fpops_available=true };
};
template <>
struct unpacket_traits<Packet8d> {
typedef double type;
typedef Packet4d half;
enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false };
typedef uint8_t mask_t;
enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true, masked_fpops_available=true };
};
template <>
struct unpacket_traits<Packet16i> {
@@ -244,11 +245,25 @@ template<> EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) {

template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
// Inline asm here helps reduce some register spilling in TRSM kernels.
// See note in unrolls::gemm::microKernel in trsmKernel_impl.hpp
Packet16f ret;
__asm__ ("vbroadcastss %[mem], %[dst]" : [dst] "=v" (ret) : [mem] "m" (*from));
return ret;
#else
return _mm512_broadcastss_ps(_mm_load_ps1(from));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) {
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
Packet8d ret;
__asm__ ("vbroadcastsd %[mem], %[dst]" : [dst] "=v" (ret) : [mem] "m" (*from));
return ret;
#else
return _mm512_set1_pd(*from);
#endif
}

template <>
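A standalone version of the same broadcast (not from the commit) can be compiled in isolation to inspect the generated code; the "m" constraint keeps the source as a memory operand, which is what the comment above credits with reducing register pressure in the TRSM kernels:

#include <immintrin.h>
// GCC/Clang extended asm, requires AVX512F (e.g. -mavx512f).
__m512 broadcast_ss(const float* from) {
  __m512 ret;
  __asm__("vbroadcastss %[mem], %[dst]" : [dst] "=v"(ret) : [mem] "m"(*from));
  return ret;
}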
@@ -285,6 +300,20 @@ EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a,
const Packet16i& b) {
return _mm512_add_epi32(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a,
const Packet16f& b,
uint16_t umask) {
__mmask16 mask = static_cast<__mmask16>(umask);
return _mm512_maskz_add_ps(mask, a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a,
const Packet8d& b,
uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
return _mm512_maskz_add_pd(mask, a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a,
@@ -771,12 +800,16 @@ EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512(
reinterpret_cast<const __m512i*>(from));
}

template <>
EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from, uint16_t umask) {
__mmask16 mask = static_cast<__mmask16>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from);
}
template <>
EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from, uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_pd(mask, from);
}

// Loads 8 floats from memory a returns the packet
// {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
@@ -886,6 +919,11 @@ EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from, uint16
__mmask16 mask = static_cast<__mmask16>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from, uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from);
}

template <>
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
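A hedged usage sketch of the AVX512 masked load/add/store path for a partial tail of rem < 16 floats (illustrative helper, not part of the patch; lanes outside the mask neither fault on load nor get written on store):

#include <immintrin.h>
void add_tail_avx512(float* dst, const float* x, const float* y, int rem) {
  __mmask16 k = static_cast<__mmask16>((1u << rem) - 1u);   // enable the low rem lanes
  __m512 a = _mm512_maskz_loadu_ps(k, x);                    // disabled lanes load as zero
  __m512 b = _mm512_maskz_loadu_ps(k, y);
  __m512 s = _mm512_add_ps(a, b);
  _mm512_mask_storeu_ps(dst, k, s);                          // only enabled lanes are stored
}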
@@ -1392,6 +1430,56 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 16>& kernel) {
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \
INPUT[2 * INDEX + STRIDE]);

EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]);
__m512 T2 = _mm512_unpacklo_ps(kernel.packet[2],kernel.packet[3]);
__m512 T3 = _mm512_unpackhi_ps(kernel.packet[2],kernel.packet[3]);
__m512 T4 = _mm512_unpacklo_ps(kernel.packet[4],kernel.packet[5]);
__m512 T5 = _mm512_unpackhi_ps(kernel.packet[4],kernel.packet[5]);
__m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]);
__m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]);

kernel.packet[0] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2)));
kernel.packet[1] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2)));
kernel.packet[2] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3)));
kernel.packet[3] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3)));
kernel.packet[4] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6)));
kernel.packet[5] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6)));
kernel.packet[6] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7)));
kernel.packet[7] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7)));

T0 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[4]), 0x4E));
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
T4 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[0]), 0x4E));
T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
T1 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[5]), 0x4E));
T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
T5 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[1]), 0x4E));
T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
T2 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[6]), 0x4E));
T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
T6 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[2]), 0x4E));
T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
T3 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[7]), 0x4E));
T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
T7 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[3]), 0x4E));
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);

kernel.packet[0] = T0; kernel.packet[1] = T1;
kernel.packet[2] = T2; kernel.packet[3] = T3;
kernel.packet[4] = T4; kernel.packet[5] = T5;
kernel.packet[6] = T6; kernel.packet[7] = T7;
}

EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
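The new kernel is the usual unpack/blend transpose butterfly carried out at 512-bit width, with lane-crossing permutes and mask blends standing in for the 128-bit moves. The same idea in miniature, as a 4x4 float transpose with SSE intrinsics (illustrative only, not part of the commit):

#include <xmmintrin.h>
void transpose4x4(__m128& r0, __m128& r1, __m128& r2, __m128& r3) {
  __m128 t0 = _mm_unpacklo_ps(r0, r1);  // a0 b0 a1 b1
  __m128 t1 = _mm_unpackhi_ps(r0, r1);  // a2 b2 a3 b3
  __m128 t2 = _mm_unpacklo_ps(r2, r3);  // c0 d0 c1 d1
  __m128 t3 = _mm_unpackhi_ps(r2, r3);  // c2 d2 c3 d3
  r0 = _mm_movelh_ps(t0, t2);           // a0 b0 c0 d0
  r1 = _mm_movehl_ps(t2, t0);           // a1 b1 c1 d1
  r2 = _mm_movelh_ps(t1, t3);           // a2 b2 c2 d2
  r3 = _mm_movehl_ps(t3, t1);           // a3 b3 c3 d3
}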
@@ -1468,62 +1556,53 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
}

EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
__m512d T0 = _mm512_unpacklo_pd(kernel.packet[0], kernel.packet[1]);
__m512d T1 = _mm512_unpackhi_pd(kernel.packet[0], kernel.packet[1]);
__m512d T2 = _mm512_unpacklo_pd(kernel.packet[2], kernel.packet[3]);
__m512d T3 = _mm512_unpackhi_pd(kernel.packet[2], kernel.packet[3]);
__m512d T4 = _mm512_unpacklo_pd(kernel.packet[4], kernel.packet[5]);
__m512d T5 = _mm512_unpackhi_pd(kernel.packet[4], kernel.packet[5]);
__m512d T6 = _mm512_unpacklo_pd(kernel.packet[6], kernel.packet[7]);
__m512d T7 = _mm512_unpackhi_pd(kernel.packet[6], kernel.packet[7]);
__m512d T0 = _mm512_unpacklo_pd(kernel.packet[0],kernel.packet[1]);
__m512d T1 = _mm512_unpackhi_pd(kernel.packet[0],kernel.packet[1]);
__m512d T2 = _mm512_unpacklo_pd(kernel.packet[2],kernel.packet[3]);
__m512d T3 = _mm512_unpackhi_pd(kernel.packet[2],kernel.packet[3]);
__m512d T4 = _mm512_unpacklo_pd(kernel.packet[4],kernel.packet[5]);
__m512d T5 = _mm512_unpackhi_pd(kernel.packet[4],kernel.packet[5]);
__m512d T6 = _mm512_unpacklo_pd(kernel.packet[6],kernel.packet[7]);
__m512d T7 = _mm512_unpackhi_pd(kernel.packet[6],kernel.packet[7]);

PacketBlock<Packet4d, 16> tmp;
kernel.packet[0] = _mm512_permutex_pd(T2, 0x4E);
kernel.packet[0] = _mm512_mask_blend_pd(0xCC, T0, kernel.packet[0]);
kernel.packet[2] = _mm512_permutex_pd(T0, 0x4E);
kernel.packet[2] = _mm512_mask_blend_pd(0xCC, kernel.packet[2], T2);
kernel.packet[1] = _mm512_permutex_pd(T3, 0x4E);
kernel.packet[1] = _mm512_mask_blend_pd(0xCC, T1, kernel.packet[1]);
kernel.packet[3] = _mm512_permutex_pd(T1, 0x4E);
kernel.packet[3] = _mm512_mask_blend_pd(0xCC, kernel.packet[3], T3);
kernel.packet[4] = _mm512_permutex_pd(T6, 0x4E);
kernel.packet[4] = _mm512_mask_blend_pd(0xCC, T4, kernel.packet[4]);
kernel.packet[6] = _mm512_permutex_pd(T4, 0x4E);
kernel.packet[6] = _mm512_mask_blend_pd(0xCC, kernel.packet[6], T6);
kernel.packet[5] = _mm512_permutex_pd(T7, 0x4E);
kernel.packet[5] = _mm512_mask_blend_pd(0xCC, T5, kernel.packet[5]);
kernel.packet[7] = _mm512_permutex_pd(T5, 0x4E);
kernel.packet[7] = _mm512_mask_blend_pd(0xCC, kernel.packet[7], T7);

tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
_mm512_extractf64x4_pd(T2, 0), 0x20);
tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
_mm512_extractf64x4_pd(T3, 0), 0x20);
tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
_mm512_extractf64x4_pd(T2, 0), 0x31);
tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
_mm512_extractf64x4_pd(T3, 0), 0x31);
T0 = _mm512_shuffle_f64x2(kernel.packet[4], kernel.packet[4], 0x4E);
T0 = _mm512_mask_blend_pd(0xF0, kernel.packet[0], T0);
T4 = _mm512_shuffle_f64x2(kernel.packet[0], kernel.packet[0], 0x4E);
T4 = _mm512_mask_blend_pd(0xF0, T4, kernel.packet[4]);
T1 = _mm512_shuffle_f64x2(kernel.packet[5], kernel.packet[5], 0x4E);
T1 = _mm512_mask_blend_pd(0xF0, kernel.packet[1], T1);
T5 = _mm512_shuffle_f64x2(kernel.packet[1], kernel.packet[1], 0x4E);
T5 = _mm512_mask_blend_pd(0xF0, T5, kernel.packet[5]);
T2 = _mm512_shuffle_f64x2(kernel.packet[6], kernel.packet[6], 0x4E);
T2 = _mm512_mask_blend_pd(0xF0, kernel.packet[2], T2);
T6 = _mm512_shuffle_f64x2(kernel.packet[2], kernel.packet[2], 0x4E);
T6 = _mm512_mask_blend_pd(0xF0, T6, kernel.packet[6]);
T3 = _mm512_shuffle_f64x2(kernel.packet[7], kernel.packet[7], 0x4E);
T3 = _mm512_mask_blend_pd(0xF0, kernel.packet[3], T3);
T7 = _mm512_shuffle_f64x2(kernel.packet[3], kernel.packet[3], 0x4E);
T7 = _mm512_mask_blend_pd(0xF0, T7, kernel.packet[7]);

tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
_mm512_extractf64x4_pd(T2, 1), 0x20);
tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
_mm512_extractf64x4_pd(T3, 1), 0x20);
tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
_mm512_extractf64x4_pd(T2, 1), 0x31);
tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
_mm512_extractf64x4_pd(T3, 1), 0x31);

tmp.packet[8] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
_mm512_extractf64x4_pd(T6, 0), 0x20);
tmp.packet[9] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
_mm512_extractf64x4_pd(T7, 0), 0x20);
tmp.packet[10] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
_mm512_extractf64x4_pd(T6, 0), 0x31);
tmp.packet[11] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
_mm512_extractf64x4_pd(T7, 0), 0x31);

tmp.packet[12] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
_mm512_extractf64x4_pd(T6, 1), 0x20);
tmp.packet[13] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
_mm512_extractf64x4_pd(T7, 1), 0x20);
tmp.packet[14] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
_mm512_extractf64x4_pd(T6, 1), 0x31);
tmp.packet[15] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
_mm512_extractf64x4_pd(T7, 1), 0x31);

PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 0, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 1, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 2, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 3, 8);

PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 4, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 5, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 6, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 7, 8);
kernel.packet[0] = T0; kernel.packet[1] = T1;
kernel.packet[2] = T2; kernel.packet[3] = T3;
kernel.packet[4] = T4; kernel.packet[5] = T5;
kernel.packet[6] = T6; kernel.packet[7] = T7;
}

#define PACK_OUTPUT_I32(OUTPUT, INPUT, INDEX, STRIDE) \
Eigen/src/Core/arch/AVX512/trsmKernel_impl.hpp: new file, 1106 lines (file diff suppressed because it is too large)
Eigen/src/Core/arch/AVX512/unrolls_impl.hpp: new file, 1213 lines (file diff suppressed because it is too large)
Eigen/src/Core/products/TriangularSolverMatrix.h
@@ -2,6 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
// Modifications Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -16,6 +17,114 @@ namespace Eigen {

namespace internal {

template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
struct trsm_kernels {
// Generic Implementation of triangular solve for triangular matrix on left and multiple rhs.
// Handles non-packed matrices.
static void trsmKernelL(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride);

// Generic Implementation of triangular solve for triangular matrix on right and multiple lhs.
// Handles non-packed matrices.
static void trsmKernelR(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride);
};

template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
EIGEN_STRONG_INLINE void trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelL(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride)
{
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
TriMapper tri(_tri, triStride);
OtherMapper other(_other, otherStride, otherIncr);

enum { IsLower = (Mode&Lower) == Lower };
conj_if<Conjugate> conj;

// tr solve
for (Index k=0; k<size; ++k)
{
// TODO write a small kernel handling this (can be shared with trsv)
Index i = IsLower ? k : -k-1;
Index rs = size - k - 1; // remaining size
Index s = TriStorageOrder==RowMajor ? (IsLower ? 0 : i+1)
: IsLower ? i+1 : i-rs;

Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
for (Index j=0; j<otherSize; ++j)
{
if (TriStorageOrder==RowMajor)
{
Scalar b(0);
const Scalar* l = &tri(i,s);
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
for (Index i3=0; i3<k; ++i3)
b += conj(l[i3]) * r(i3);

other(i,j) = (other(i,j) - b)*a;
}
else
{
Scalar& otherij = other(i,j);
otherij *= a;
Scalar b = otherij;
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
for (Index i3=0;i3<rs;++i3)
r(i3) -= b * conj(l(i3));
}
}
}
}


template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_STRONG_INLINE void trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelR(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride)
{
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
LhsMapper lhs(_other, otherStride, otherIncr);
RhsMapper rhs(_tri, triStride);

enum {
RhsStorageOrder = TriStorageOrder,
IsLower = (Mode&Lower) == Lower
};
conj_if<Conjugate> conj;

for (Index k=0; k<size; ++k)
{
Index j = IsLower ? size-k-1 : k;

typename LhsMapper::LinearMapper r = lhs.getLinearMapper(0,j);
for (Index k3=0; k3<k; ++k3)
{
Scalar b = conj(rhs(IsLower ? j+1+k3 : k3,j));
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0,IsLower ? j+1+k3 : k3);
for (Index i=0; i<otherSize; ++i)
r(i) -= a(i) * b;
}
if((Mode & UnitDiag)==0)
{
Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
for (Index i=0; i<otherSize; ++i)
r(i) *= inv_rjj;
}
}
}


// if the rhs is row major, let's transpose the product
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor,OtherInnerStride>
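trsmKernelL and trsmKernelR above are the unblocked substitution loops factored out of the packing code below, so that specialized AVX512 kernels (provided by the new trsmKernel_impl.hpp) can take their place when profitable. For the Lower, ColMajor, non-unit-diagonal case, trsmKernelL is equivalent to the following plain-scalar in-place solve of L * X = B, where B has nrhs columns and leading dimension ldb (a hedged model, not part of the patch):

// Solve L * X = B in place; L is size x size lower triangular, B is size x nrhs.
void lower_solve_inplace(int size, int nrhs, const double* L, int ldl, double* B, int ldb) {
  for (int k = 0; k < size; ++k) {
    double inv_diag = 1.0 / L[k + k * ldl];
    for (int j = 0; j < nrhs; ++j) {
      B[k + j * ldb] *= inv_diag;             // x_k = b_k / l_kk
      double xk = B[k + j * ldb];
      for (int i = k + 1; i < size; ++i)      // eliminate x_k from the remaining rows
        B[i + j * ldb] -= xk * L[i + k * ldl];
    }
  }
}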
@@ -46,6 +155,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};

template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
@@ -55,6 +165,25 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
{
Index cols = otherSize;

std::ptrdiff_t l1, l2, l3;
manage_caching_sizes(GetAction, &l1, &l2, &l3);

#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
(std::is_same<Scalar,float>::value ||
std::is_same<Scalar,double>::value)) ) {
// Very rough cutoffs to determine when to call trsm w/o packing
// For small problem sizes trsmKernel compiled with clang is generally faster.
// TODO: Investigate better heuristics for cutoffs.
double L2Cap = 0.5; // 50% of L2 size
if (size < avx512_trsm_cutoff<Scalar>(l2, cols, L2Cap)) {
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, 1>::trsmKernelL(
size, cols, _tri, triStride, _other, 1, otherStride);
return;
}
}
#endif

typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
TriMapper tri(_tri, triStride);
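avx512_trsm_cutoff comes from the new trsmKernel_impl.hpp, whose diff is suppressed above, so its exact formula is not visible here; the comments state the intent, namely taking the no-copy path only while the problem stays within L2Cap of the L2 cache. A rough stand-in with that shape (an assumption for illustration, not the actual implementation):

#include <cmath>
#include <cstddef>
// Hypothetical threshold on `size` such that a size x size triangle plus a
// size x cols panel fits in roughly L2Cap * l2 bytes.
template <typename Scalar>
std::ptrdiff_t rough_trsm_cutoff(std::ptrdiff_t l2, std::ptrdiff_t cols, double L2Cap) {
  double budget = L2Cap * double(l2) / double(sizeof(Scalar));
  double c = double(cols);
  // size*size/2 + size*cols <= budget, solved for size
  return static_cast<std::ptrdiff_t>(std::sqrt(c * c + 2.0 * budget) - c);
}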
@@ -76,15 +205,12 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());

conj_if<Conjugate> conj;
gebp_kernel<Scalar, Scalar, Index, OtherMapper, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
gemm_pack_lhs<Scalar, Index, TriMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, TriStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, OtherMapper, Traits::nr, ColMajor, false, true> pack_rhs;

// the goal here is to subdivise the Rhs panels such that we keep some cache
// coherence when accessing the rhs elements
std::ptrdiff_t l1, l2, l3;
manage_caching_sizes(GetAction, &l1, &l2, &l3);
Index subcols = cols>0 ? l2/(4 * sizeof(Scalar) * std::max<Index>(otherStride,size)) : 0;
subcols = std::max<Index>((subcols/Traits::nr)*Traits::nr, Traits::nr);
@@ -115,38 +241,19 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
{
Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth);
// tr solve
for (Index k=0; k<actualPanelWidth; ++k)
{
// TODO write a small kernel handling this (can be shared with trsv)
Index i = IsLower ? k2+k1+k : k2-k1-k-1;
Index rs = actualPanelWidth - k - 1; // remaining size
Index s = TriStorageOrder==RowMajor ? (IsLower ? k2+k1 : i+1)
: IsLower ? i+1 : i-rs;

Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
for (Index j=j2; j<j2+actual_cols; ++j)
{
if (TriStorageOrder==RowMajor)
{
Scalar b(0);
const Scalar* l = &tri(i,s);
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
for (Index i3=0; i3<k; ++i3)
b += conj(l[i3]) * r(i3);

other(i,j) = (other(i,j) - b)*a;
}
else
{
Scalar& otherij = other(i,j);
otherij *= a;
Scalar b = otherij;
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
for (Index i3=0;i3<rs;++i3)
r(i3) -= b * conj(l(i3));
}
Index i = IsLower ? k2+k1 : k2-k1;
#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
(std::is_same<Scalar,float>::value ||
std::is_same<Scalar,double>::value)) ) {
i = IsLower ? k2 + k1: k2 - k1 - actualPanelWidth;
}
#endif
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelL(
actualPanelWidth, actual_cols,
_tri + i + (i)*triStride, triStride,
_other + i*OtherInnerStride + j2*otherStride, otherIncr, otherStride);
}

Index lengthTarget = actual_kc-k1-actualPanelWidth;
@@ -198,6 +305,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};

template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
@@ -206,7 +314,22 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
level3_blocking<Scalar,Scalar>& blocking)
{
Index rows = otherSize;
typedef typename NumTraits<Scalar>::Real RealScalar;

#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
(std::is_same<Scalar,float>::value ||
std::is_same<Scalar,double>::value)) ) {
// TODO: Investigate better heuristics for cutoffs.
std::ptrdiff_t l1, l2, l3;
manage_caching_sizes(GetAction, &l1, &l2, &l3);
double L2Cap = 0.5; // 50% of L2 size
if (size < avx512_trsm_cutoff<Scalar>(l2, rows, L2Cap)) {
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
trsmKernelR(size, rows, _tri, triStride, _other, 1, otherStride);
return;
}
}
#endif

typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
@@ -229,7 +352,6 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());

conj_if<Conjugate> conj;
gebp_kernel<Scalar, Scalar, Index, LhsMapper, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder,false,true> pack_rhs_panel;
@@ -296,27 +418,13 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
panelOffset, panelOffset); // offsets
}

// unblocked triangular solve
for (Index k=0; k<actualPanelWidth; ++k)
{
Index j = IsLower ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;

typename LhsMapper::LinearMapper r = lhs.getLinearMapper(i2,j);
for (Index k3=0; k3<k; ++k3)
{
Scalar b = conj(rhs(IsLower ? j+1+k3 : absolute_j2+k3,j));
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(i2,IsLower ? j+1+k3 : absolute_j2+k3);
for (Index i=0; i<actual_mc; ++i)
r(i) -= a(i) * b;
}
if((Mode & UnitDiag)==0)
{
Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
for (Index i=0; i<actual_mc; ++i)
r(i) *= inv_rjj;
}
// unblocked triangular solve
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
trsmKernelR(actualPanelWidth, actual_mc,
_tri + absolute_j2 + absolute_j2*triStride, triStride,
_other + i2*OtherInnerStride + absolute_j2*otherStride, otherIncr, otherStride);
}

// pack the just computed part of lhs to A
pack_lhs_panel(blockA, lhs.getSubMapper(i2,absolute_j2),
actualPanelWidth, actual_mc,
@@ -331,7 +439,6 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
}
}
}

} // end namespace internal

} // end namespace Eigen