AVX512 Optimizations for Triangular Solve

b-shi 2022-03-16 18:04:50 +00:00 committed by Rasmus Munk Larsen
parent 01b5bc48cc
commit 518fc321cb
7 changed files with 2664 additions and 128 deletions

View File

@@ -190,6 +190,7 @@ using std::ptrdiff_t;
#include "src/Core/arch/SSE/MathFunctions.h"
#include "src/Core/arch/AVX/MathFunctions.h"
#include "src/Core/arch/AVX512/MathFunctions.h"
#include "src/Core/arch/AVX512/trsmKernel_impl.hpp"
#elif defined EIGEN_VECTORIZE_AVX
// Use AVX for floats and doubles, SSE for integers
#include "src/Core/arch/SSE/PacketMath.h"

View File

@@ -220,6 +220,15 @@ padd(const Packet& a, const Packet& b) { return a+b; }
template<> EIGEN_DEVICE_FUNC inline bool
padd(const bool& a, const bool& b) { return a || b; }
/** \internal \returns the sum of \a a and \a b for the lanes selected by \a umask (masked add).
  * There is no generic implementation. We only have implementations for specialized
  * cases. Generic case should not be called.
  */
template<typename Packet> EIGEN_DEVICE_FUNC inline
std::enable_if_t<unpacket_traits<Packet>::masked_fpops_available, Packet>
padd(const Packet& a, const Packet& b, typename unpacket_traits<Packet>::mask_t umask);
/** \internal \returns a - b (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
psub(const Packet& a, const Packet& b) { return a-b; }
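The masked overload is declared only for packets whose unpacket_traits advertise masked_fpops_available, so callers gate on that trait. A minimal sketch of such a caller (the helper name and the low-bits mask are illustrative, not part of this commit):
// Hypothetical trait-gated caller: add only the low `remaining` lanes.
template <typename Packet>
std::enable_if_t<unpacket_traits<Packet>::masked_fpops_available, Packet>
padd_partial(const Packet& a, const Packet& b, int remaining) {
  typedef typename unpacket_traits<Packet>::mask_t mask_t;
  // Bit i of the mask enables lane i; the AVX512 implementations zero the rest.
  mask_t umask = static_cast<mask_t>((1u << remaining) - 1u);
  return padd(a, b, umask);
}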

View File

@@ -236,7 +236,11 @@ template<> struct unpacket_traits<Packet8f> {
typedef Packet4f half;
typedef Packet8i integer_packet;
typedef uint8_t mask_t;
enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true};
enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true
#ifdef EIGEN_VECTORIZE_AVX512
, masked_fpops_available=true
#endif
};
};
template<> struct unpacket_traits<Packet4d> {
typedef double type;
@@ -464,6 +468,13 @@ template<> EIGEN_STRONG_INLINE Packet8f pload1<Packet8f>(const float* from) { r
template<> EIGEN_STRONG_INLINE Packet4d pload1<Packet4d>(const double* from) { return _mm256_broadcast_sd(from); }
template<> EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_add_ps(a,b); }
#ifdef EIGEN_VECTORIZE_AVX512
template <>
EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b, uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
return _mm256_maskz_add_ps(mask, a, b);
}
#endif
template<> EIGEN_STRONG_INLINE Packet4d padd<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_add_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet8i padd<Packet8i>(const Packet8i& a, const Packet8i& b) {
#ifdef EIGEN_VECTORIZE_AVX2
@@ -848,11 +859,16 @@ template<> EIGEN_STRONG_INLINE Packet4d ploadu<Packet4d>(const double* from) { E
template<> EIGEN_STRONG_INLINE Packet8i ploadu<Packet8i>(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); }
template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from, uint8_t umask) {
#ifdef EIGEN_VECTORIZE_AVX512
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskz_loadu_ps(mask, from);
#else
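// AVX's maskload takes a per-lane vector mask rather than a bit mask, so expand
// the 8-bit mask: broadcast the mask byte to every byte, force every bit except
// lane i's own bit i to 1, then compare against all-ones so lane i becomes
// 0xffffffff exactly when bit i of umask is set.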
Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
mask = por<Packet8i>(mask, bit_mask);
mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskload_ps(from, mask);
#endif
}
// Loads 4 floats from memory and returns the packet {a0, a0, a1, a1, a2, a2, a3, a3}
@@ -911,11 +927,16 @@ template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet4d&
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from, uint8_t umask) {
#ifdef EIGEN_VECTORIZE_AVX512
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm256_mask_storeu_ps(to, mask, from);
#else
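// Same bit-to-lane mask expansion as in the masked ploadu<Packet8f> above.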
Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask));
const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe);
mask = por<Packet8i>(mask, bit_mask);
mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff));
EIGEN_DEBUG_UNALIGNED_STORE return _mm256_maskstore_ps(to, mask, from);
#endif
}
// NOTE: leverage _mm256_i32gather_ps and _mm256_i32gather_pd if AVX2 instructions are available

View File

@@ -180,13 +180,14 @@ struct unpacket_traits<Packet16f> {
typedef Packet8f half;
typedef Packet16i integer_packet;
typedef uint16_t mask_t;
enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true };
enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true, masked_fpops_available=true };
};
template <>
struct unpacket_traits<Packet8d> {
typedef double type;
typedef Packet4d half;
enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false };
typedef uint8_t mask_t;
enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true, masked_fpops_available=true };
};
template <>
struct unpacket_traits<Packet16i> {
@@ -244,11 +245,25 @@ template<> EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) {
template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
// Inline asm here helps reduce some register spilling in TRSM kernels.
// See note in unrolls::gemm::microKernel in trsmKernel_impl.hpp
Packet16f ret;
__asm__ ("vbroadcastss %[mem], %[dst]" : [dst] "=v" (ret) : [mem] "m" (*from));
return ret;
#else
return _mm512_broadcastss_ps(_mm_load_ps1(from));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) {
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
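// Inline asm for the same register-spilling reason as pload1<Packet16f> above.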
Packet8d ret;
__asm__ ("vbroadcastsd %[mem], %[dst]" : [dst] "=v" (ret) : [mem] "m" (*from));
return ret;
#else
return _mm512_set1_pd(*from);
#endif
}
template <>
@@ -285,6 +300,20 @@ EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a,
const Packet16i& b) {
return _mm512_add_epi32(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a,
const Packet16f& b,
uint16_t umask) {
__mmask16 mask = static_cast<__mmask16>(umask);
return _mm512_maskz_add_ps(mask, a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a,
const Packet8d& b,
uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
return _mm512_maskz_add_pd(mask, a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a,
@@ -771,12 +800,16 @@ EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512(
reinterpret_cast<const __m512i*>(from));
}
template <>
EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from, uint16_t umask) {
__mmask16 mask = static_cast<__mmask16>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from);
}
template <>
EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from, uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_pd(mask, from);
}
// Loads 8 floats from memory and returns the packet
// {a0, a0, a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
@@ -886,6 +919,11 @@ EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from, uint16
__mmask16 mask = static_cast<__mmask16>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from, uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from);
}
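Together with the masked ploadu/pstoreu overloads above, the masked padd lets a kernel update a partial column without a scalar tail loop. A hypothetical remainder helper (illustrative only, not part of this commit):
// Add the low `remaining` doubles of src into dst, remaining < 8.
void add_tail_pd(double* dst, const double* src, int remaining) {
  uint8_t umask = static_cast<uint8_t>((1u << remaining) - 1u); // enable low lanes only
  Packet8d a = ploadu<Packet8d>(dst, umask);   // masked-zero load
  Packet8d b = ploadu<Packet8d>(src, umask);
  // maskz add zeroes the unselected lanes; the masked store never touches them.
  pstoreu<double>(dst, padd<Packet8d>(a, b, umask), umask);
}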
template <>
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
@@ -1392,6 +1430,56 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 16>& kernel) {
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \
INPUT[2 * INDEX + STRIDE]);
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]);
__m512 T2 = _mm512_unpacklo_ps(kernel.packet[2],kernel.packet[3]);
__m512 T3 = _mm512_unpackhi_ps(kernel.packet[2],kernel.packet[3]);
__m512 T4 = _mm512_unpacklo_ps(kernel.packet[4],kernel.packet[5]);
__m512 T5 = _mm512_unpackhi_ps(kernel.packet[4],kernel.packet[5]);
__m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]);
__m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]);
kernel.packet[0] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2)));
kernel.packet[1] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2)));
kernel.packet[2] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3)));
kernel.packet[3] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3)));
kernel.packet[4] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6)));
kernel.packet[5] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6)));
kernel.packet[6] = reinterpret_cast<__m512>(
_mm512_unpacklo_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7)));
kernel.packet[7] = reinterpret_cast<__m512>(
_mm512_unpackhi_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7)));
T0 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[4]), 0x4E));
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
T4 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[0]), 0x4E));
T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
T1 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[5]), 0x4E));
T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
T5 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[1]), 0x4E));
T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
T2 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[6]), 0x4E));
T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
T6 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[2]), 0x4E));
T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
T3 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[7]), 0x4E));
T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
T7 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[3]), 0x4E));
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
kernel.packet[0] = T0; kernel.packet[1] = T1;
kernel.packet[2] = T2; kernel.packet[3] = T3;
kernel.packet[4] = T4; kernel.packet[5] = T5;
kernel.packet[6] = T6; kernel.packet[7] = T7;
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
@@ -1468,62 +1556,53 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
__m512d T0 = _mm512_unpacklo_pd(kernel.packet[0], kernel.packet[1]);
__m512d T1 = _mm512_unpackhi_pd(kernel.packet[0], kernel.packet[1]);
__m512d T2 = _mm512_unpacklo_pd(kernel.packet[2], kernel.packet[3]);
__m512d T3 = _mm512_unpackhi_pd(kernel.packet[2], kernel.packet[3]);
__m512d T4 = _mm512_unpacklo_pd(kernel.packet[4], kernel.packet[5]);
__m512d T5 = _mm512_unpackhi_pd(kernel.packet[4], kernel.packet[5]);
__m512d T6 = _mm512_unpacklo_pd(kernel.packet[6], kernel.packet[7]);
__m512d T7 = _mm512_unpackhi_pd(kernel.packet[6], kernel.packet[7]);
__m512d T0 = _mm512_unpacklo_pd(kernel.packet[0],kernel.packet[1]);
__m512d T1 = _mm512_unpackhi_pd(kernel.packet[0],kernel.packet[1]);
__m512d T2 = _mm512_unpacklo_pd(kernel.packet[2],kernel.packet[3]);
__m512d T3 = _mm512_unpackhi_pd(kernel.packet[2],kernel.packet[3]);
__m512d T4 = _mm512_unpacklo_pd(kernel.packet[4],kernel.packet[5]);
__m512d T5 = _mm512_unpackhi_pd(kernel.packet[4],kernel.packet[5]);
__m512d T6 = _mm512_unpacklo_pd(kernel.packet[6],kernel.packet[7]);
__m512d T7 = _mm512_unpackhi_pd(kernel.packet[6],kernel.packet[7]);
PacketBlock<Packet4d, 16> tmp;
kernel.packet[0] = _mm512_permutex_pd(T2, 0x4E);
kernel.packet[0] = _mm512_mask_blend_pd(0xCC, T0, kernel.packet[0]);
kernel.packet[2] = _mm512_permutex_pd(T0, 0x4E);
kernel.packet[2] = _mm512_mask_blend_pd(0xCC, kernel.packet[2], T2);
kernel.packet[1] = _mm512_permutex_pd(T3, 0x4E);
kernel.packet[1] = _mm512_mask_blend_pd(0xCC, T1, kernel.packet[1]);
kernel.packet[3] = _mm512_permutex_pd(T1, 0x4E);
kernel.packet[3] = _mm512_mask_blend_pd(0xCC, kernel.packet[3], T3);
kernel.packet[4] = _mm512_permutex_pd(T6, 0x4E);
kernel.packet[4] = _mm512_mask_blend_pd(0xCC, T4, kernel.packet[4]);
kernel.packet[6] = _mm512_permutex_pd(T4, 0x4E);
kernel.packet[6] = _mm512_mask_blend_pd(0xCC, kernel.packet[6], T6);
kernel.packet[5] = _mm512_permutex_pd(T7, 0x4E);
kernel.packet[5] = _mm512_mask_blend_pd(0xCC, T5, kernel.packet[5]);
kernel.packet[7] = _mm512_permutex_pd(T5, 0x4E);
kernel.packet[7] = _mm512_mask_blend_pd(0xCC, kernel.packet[7], T7);
tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
_mm512_extractf64x4_pd(T2, 0), 0x20);
tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
_mm512_extractf64x4_pd(T3, 0), 0x20);
tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
_mm512_extractf64x4_pd(T2, 0), 0x31);
tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
_mm512_extractf64x4_pd(T3, 0), 0x31);
T0 = _mm512_shuffle_f64x2(kernel.packet[4], kernel.packet[4], 0x4E);
T0 = _mm512_mask_blend_pd(0xF0, kernel.packet[0], T0);
T4 = _mm512_shuffle_f64x2(kernel.packet[0], kernel.packet[0], 0x4E);
T4 = _mm512_mask_blend_pd(0xF0, T4, kernel.packet[4]);
T1 = _mm512_shuffle_f64x2(kernel.packet[5], kernel.packet[5], 0x4E);
T1 = _mm512_mask_blend_pd(0xF0, kernel.packet[1], T1);
T5 = _mm512_shuffle_f64x2(kernel.packet[1], kernel.packet[1], 0x4E);
T5 = _mm512_mask_blend_pd(0xF0, T5, kernel.packet[5]);
T2 = _mm512_shuffle_f64x2(kernel.packet[6], kernel.packet[6], 0x4E);
T2 = _mm512_mask_blend_pd(0xF0, kernel.packet[2], T2);
T6 = _mm512_shuffle_f64x2(kernel.packet[2], kernel.packet[2], 0x4E);
T6 = _mm512_mask_blend_pd(0xF0, T6, kernel.packet[6]);
T3 = _mm512_shuffle_f64x2(kernel.packet[7], kernel.packet[7], 0x4E);
T3 = _mm512_mask_blend_pd(0xF0, kernel.packet[3], T3);
T7 = _mm512_shuffle_f64x2(kernel.packet[3], kernel.packet[3], 0x4E);
T7 = _mm512_mask_blend_pd(0xF0, T7, kernel.packet[7]);
tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
_mm512_extractf64x4_pd(T2, 1), 0x20);
tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
_mm512_extractf64x4_pd(T3, 1), 0x20);
tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
_mm512_extractf64x4_pd(T2, 1), 0x31);
tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
_mm512_extractf64x4_pd(T3, 1), 0x31);
tmp.packet[8] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
_mm512_extractf64x4_pd(T6, 0), 0x20);
tmp.packet[9] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
_mm512_extractf64x4_pd(T7, 0), 0x20);
tmp.packet[10] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
_mm512_extractf64x4_pd(T6, 0), 0x31);
tmp.packet[11] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
_mm512_extractf64x4_pd(T7, 0), 0x31);
tmp.packet[12] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
_mm512_extractf64x4_pd(T6, 1), 0x20);
tmp.packet[13] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
_mm512_extractf64x4_pd(T7, 1), 0x20);
tmp.packet[14] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
_mm512_extractf64x4_pd(T6, 1), 0x31);
tmp.packet[15] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
_mm512_extractf64x4_pd(T7, 1), 0x31);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 0, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 1, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 2, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 3, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 4, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 5, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 6, 8);
PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 7, 8);
kernel.packet[0] = T0; kernel.packet[1] = T1;
kernel.packet[2] = T2; kernel.packet[3] = T3;
kernel.packet[4] = T4; kernel.packet[5] = T5;
kernel.packet[6] = T6; kernel.packet[7] = T7;
}
#define PACK_OUTPUT_I32(OUTPUT, INPUT, INDEX, STRIDE) \

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -2,6 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
// Modifications Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -16,6 +17,114 @@ namespace Eigen {
namespace internal {
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
struct trsm_kernels {
// Generic implementation of the triangular solve for a triangular matrix on the left and multiple right-hand sides.
// Handles non-packed matrices.
static void trsmKernelL(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride);
// Generic implementation of the triangular solve for a triangular matrix on the right and multiple left-hand sides.
// Handles non-packed matrices.
static void trsmKernelR(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride);
};
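Both kernels perform plain substitution in place on `_other`. For orientation, a scalar sketch of the left-sided case for a lower-triangular, non-unit-diagonal, column-major `_tri` with unit inner stride (illustrative only; the definitions below handle all storage orders and modes):
// Solve T * X = B in place, B stored in `other`, T lower triangular.
void trsm_left_lower_sketch(int size, int nrhs,
                            const double* tri, int triStride,
                            double* other, int otherStride) {
  for (int k = 0; k < size; ++k) {
    double a = 1.0 / tri[k + k * triStride];          // 1 / T(k,k)
    for (int j = 0; j < nrhs; ++j) {
      double x = (other[k + j * otherStride] *= a);   // X(k,j) is now final
      for (int i = k + 1; i < size; ++i)              // eliminate it from the rows below
        other[i + j * otherStride] -= x * tri[i + k * triStride];
    }
  }
}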
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
EIGEN_STRONG_INLINE void trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelL(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride)
{
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
TriMapper tri(_tri, triStride);
OtherMapper other(_other, otherStride, otherIncr);
enum { IsLower = (Mode&Lower) == Lower };
conj_if<Conjugate> conj;
// tr solve
for (Index k=0; k<size; ++k)
{
// TODO write a small kernel handling this (can be shared with trsv)
Index i = IsLower ? k : -k-1;
Index rs = size - k - 1; // remaining size
Index s = TriStorageOrder==RowMajor ? (IsLower ? 0 : i+1)
: IsLower ? i+1 : i-rs;
Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
for (Index j=0; j<otherSize; ++j)
{
if (TriStorageOrder==RowMajor)
{
Scalar b(0);
const Scalar* l = &tri(i,s);
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
for (Index i3=0; i3<k; ++i3)
b += conj(l[i3]) * r(i3);
other(i,j) = (other(i,j) - b)*a;
}
else
{
Scalar& otherij = other(i,j);
otherij *= a;
Scalar b = otherij;
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
for (Index i3=0;i3<rs;++i3)
r(i3) -= b * conj(l(i3));
}
}
}
}
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_STRONG_INLINE void trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelR(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride)
{
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
LhsMapper lhs(_other, otherStride, otherIncr);
RhsMapper rhs(_tri, triStride);
enum {
RhsStorageOrder = TriStorageOrder,
IsLower = (Mode&Lower) == Lower
};
conj_if<Conjugate> conj;
for (Index k=0; k<size; ++k)
{
Index j = IsLower ? size-k-1 : k;
typename LhsMapper::LinearMapper r = lhs.getLinearMapper(0,j);
for (Index k3=0; k3<k; ++k3)
{
Scalar b = conj(rhs(IsLower ? j+1+k3 : k3,j));
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(0,IsLower ? j+1+k3 : k3);
for (Index i=0; i<otherSize; ++i)
r(i) -= a(i) * b;
}
if((Mode & UnitDiag)==0)
{
Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
for (Index i=0; i<otherSize; ++i)
r(i) *= inv_rjj;
}
}
}
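The right-sided kernel is the same substitution applied to columns of the solution in X * T = B. A scalar sketch for an upper-triangular, non-unit-diagonal, column-major `_tri` with unit inner stride (again illustrative only):
// Solve X * U = B in place, B stored in `other`, U upper triangular.
void trsm_right_upper_sketch(int size, int rows,
                             const double* tri, int triStride,
                             double* other, int otherStride) {
  for (int j = 0; j < size; ++j) {
    for (int k = 0; k < j; ++k) {                     // subtract already-solved columns
      double b = tri[k + j * triStride];              // U(k,j)
      for (int i = 0; i < rows; ++i)
        other[i + j * otherStride] -= other[i + k * otherStride] * b;
    }
    double inv = 1.0 / tri[j + j * triStride];        // divide column j by U(j,j)
    for (int i = 0; i < rows; ++i)
      other[i + j * otherStride] *= inv;
  }
}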
// if the rhs is row major, let's transpose the product
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor,OtherInnerStride>
@@ -46,6 +155,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
@@ -55,6 +165,25 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
{
Index cols = otherSize;
std::ptrdiff_t l1, l2, l3;
manage_caching_sizes(GetAction, &l1, &l2, &l3);
#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
(std::is_same<Scalar,float>::value ||
std::is_same<Scalar,double>::value)) ) {
// Very rough cutoffs to determine when to call trsm w/o packing
// For small problem sizes trsmKernel compiled with clang is generally faster.
// TODO: Investigate better heuristics for cutoffs.
double L2Cap = 0.5; // 50% of L2 size
if (size < avx512_trsm_cutoff<Scalar>(l2, cols, L2Cap)) {
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, 1>::trsmKernelL(
size, cols, _tri, triStride, _other, 1, otherStride);
return;
}
}
#endif
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
TriMapper tri(_tri, triStride);
@@ -76,15 +205,12 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
conj_if<Conjugate> conj;
gebp_kernel<Scalar, Scalar, Index, OtherMapper, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
gemm_pack_lhs<Scalar, Index, TriMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, TriStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, OtherMapper, Traits::nr, ColMajor, false, true> pack_rhs;
// the goal here is to subdivide the Rhs panels such that we keep some cache
// coherence when accessing the rhs elements
std::ptrdiff_t l1, l2, l3;
manage_caching_sizes(GetAction, &l1, &l2, &l3);
Index subcols = cols>0 ? l2/(4 * sizeof(Scalar) * std::max<Index>(otherStride,size)) : 0;
subcols = std::max<Index>((subcols/Traits::nr)*Traits::nr, Traits::nr);
@@ -115,38 +241,19 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
{
Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth);
// tr solve
for (Index k=0; k<actualPanelWidth; ++k)
{
// TODO write a small kernel handling this (can be shared with trsv)
Index i = IsLower ? k2+k1+k : k2-k1-k-1;
Index rs = actualPanelWidth - k - 1; // remaining size
Index s = TriStorageOrder==RowMajor ? (IsLower ? k2+k1 : i+1)
: IsLower ? i+1 : i-rs;
Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
for (Index j=j2; j<j2+actual_cols; ++j)
{
if (TriStorageOrder==RowMajor)
{
Scalar b(0);
const Scalar* l = &tri(i,s);
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
for (Index i3=0; i3<k; ++i3)
b += conj(l[i3]) * r(i3);
other(i,j) = (other(i,j) - b)*a;
}
else
{
Scalar& otherij = other(i,j);
otherij *= a;
Scalar b = otherij;
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
for (Index i3=0;i3<rs;++i3)
r(i3) -= b * conj(l(i3));
}
Index i = IsLower ? k2+k1 : k2-k1;
#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
(std::is_same<Scalar,float>::value ||
std::is_same<Scalar,double>::value)) ) {
i = IsLower ? k2 + k1: k2 - k1 - actualPanelWidth;
}
#endif
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelL(
actualPanelWidth, actual_cols,
_tri + i + (i)*triStride, triStride,
_other + i*OtherInnerStride + j2*otherStride, otherIncr, otherStride);
}
Index lengthTarget = actual_kc-k1-actualPanelWidth;
@@ -198,6 +305,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
@@ -206,7 +314,22 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
level3_blocking<Scalar,Scalar>& blocking)
{
Index rows = otherSize;
typedef typename NumTraits<Scalar>::Real RealScalar;
#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
(std::is_same<Scalar,float>::value ||
std::is_same<Scalar,double>::value)) ) {
// TODO: Investigate better heuristics for cutoffs.
std::ptrdiff_t l1, l2, l3;
manage_caching_sizes(GetAction, &l1, &l2, &l3);
double L2Cap = 0.5; // 50% of L2 size
if (size < avx512_trsm_cutoff<Scalar>(l2, rows, L2Cap)) {
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
trsmKernelR(size, rows, _tri, triStride, _other, 1, otherStride);
return;
}
}
#endif
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
@@ -229,7 +352,6 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
conj_if<Conjugate> conj;
gebp_kernel<Scalar, Scalar, Index, LhsMapper, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder,false,true> pack_rhs_panel;
@@ -296,27 +418,13 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
panelOffset, panelOffset); // offsets
}
// unblocked triangular solve
for (Index k=0; k<actualPanelWidth; ++k)
{
Index j = IsLower ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
typename LhsMapper::LinearMapper r = lhs.getLinearMapper(i2,j);
for (Index k3=0; k3<k; ++k3)
{
Scalar b = conj(rhs(IsLower ? j+1+k3 : absolute_j2+k3,j));
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(i2,IsLower ? j+1+k3 : absolute_j2+k3);
for (Index i=0; i<actual_mc; ++i)
r(i) -= a(i) * b;
}
if((Mode & UnitDiag)==0)
{
Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
for (Index i=0; i<actual_mc; ++i)
r(i) *= inv_rjj;
}
// unblocked triangular solve
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
trsmKernelR(actualPanelWidth, actual_mc,
_tri + absolute_j2 + absolute_j2*triStride, triStride,
_other + i2*OtherInnerStride + absolute_j2*otherStride, otherIncr, otherStride);
}
// pack the just computed part of lhs to A
pack_lhs_panel(blockA, lhs.getSubMapper(i2,absolute_j2),
actualPanelWidth, actual_mc,
@@ -331,7 +439,6 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
}
}
}
} // end namespace internal
} // end namespace Eigen