AVX512 TRSM kernels use alloca if EIGEN_NO_MALLOC requested
This commit is contained in: parent 4d1c16eab8, commit 37673ca1bc
@@ -12,13 +12,20 @@

#include "../../InternalHeaderCheck.h"

#define EIGEN_USE_AVX512_TRSM_KERNELS // Comment out to prevent using optimized trsm kernels.

#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
#define EIGEN_USE_AVX512_TRSM_R_KERNELS
#if !defined(EIGEN_NO_MALLOC) // Separate MACRO since these kernels require malloc
#define EIGEN_USE_AVX512_TRSM_L_KERNELS
#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
#define EIGEN_USE_AVX512_TRSM_KERNELS 1
#endif

#if EIGEN_USE_AVX512_TRSM_KERNELS
#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
#endif
#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
#endif
#else // EIGEN_USE_AVX512_TRSM_KERNELS == 0
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
#endif
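
// Illustrative usage sketch (not part of this diff): with the 0/1 convention above, the
// optimized TRSM kernels can be switched off from the build line instead of by editing
// this header; any way of predefining the macros to 0 before Eigen is included works.
// The exact compiler invocations below are an assumption, e.g.
//   c++ -march=skylake-avx512 -DEIGEN_USE_AVX512_TRSM_KERNELS=0 ...    // disable all AVX512 TRSM kernels
//   c++ -DEIGEN_NO_MALLOC -DEIGEN_USE_AVX512_TRSM_L_KERNELS=0 ...      // keep the R kernels, drop the L kernels
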
#if defined(EIGEN_HAS_CXX17_IFCONSTEXPR)

@@ -49,8 +56,7 @@ typedef Packet4d vecHalfDouble;

// Note: this depends on macros and typedefs above.
#include "TrsmUnrolls.inc"

#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
/**
 * For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly
 * is faster than the packed versions in TriangularSolverMatrix.h.
@@ -67,34 +73,46 @@ typedef Packet4d vecHalfDouble;
 * M = Dimension of triangular matrix
 *
 */
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS // Comment out to disable no-copy dispatch
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
#endif

#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
#if !defined(EIGEN_NO_MALLOC)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS

#if EIGEN_USE_AVX512_TRSM_R_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
#endif // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#endif

#if EIGEN_USE_AVX512_TRSM_L_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
#endif
#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS

#else // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS

template <typename Scalar>
int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
  const int64_t U3 = 3 * packet_traits<Scalar>::size;
  const int64_t MaxNb = 5 * U3;
  int64_t Nb = std::min(MaxNb, N);
  double cutoff_d = (((L2Size*L2Cap)/(sizeof(Scalar)))-(EIGEN_AVX_MAX_NUM_ROW)*Nb)/
    ((EIGEN_AVX_MAX_NUM_ROW)+Nb);
  double cutoff_d =
      (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
  int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
  return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
}
#endif
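
// Worked example (illustrative only, with assumed cache numbers): for float,
// packet_traits<float>::size == 16, so U3 == 48 and MaxNb == 240. Taking
// EIGEN_AVX_MAX_NUM_ROW == 8, a 256 KiB L2 (L2Size == 262144), L2Cap == 0.5 and a
// large N (Nb == 240):
//   cutoff_d = ((262144 * 0.5) / 4 - 8 * 240) / (8 + 240) = (32768 - 1920) / 248 ~= 124.4
// and the returned cutoff is (124 / 8) * 8 == 120 rows. The cache size and L2Cap value
// here are assumptions for illustration, not values used by the library.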

/**
 * Used by gemmKernel for the case A/B row-major and C col-major.
 */
template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
static EIGEN_ALWAYS_INLINE
void transStoreC(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
static EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                            Scalar *C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
  EIGEN_UNUSED_VARIABLE(remN_);
  EIGEN_UNUSED_VARIABLE(remM_);

@@ -194,16 +212,13 @@ void transStoreC(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
 * handleKRem: Handle arbitrary K? This is not needed for trsm.
 */
template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
                int64_t M, int64_t N, int64_t K,
                int64_t LDA, int64_t LDB, int64_t LDC) {
void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
                int64_t LDC) {
  using urolls = unrolls::gemm<Scalar, isAdd>;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  using vec = typename std::conditional<std::is_same<Scalar, float>::value,
                                        vecFullFloat,
                                        vecFullDouble>::type;
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  int64_t N_ = (N / U3) * U3;
  int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
  int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
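
// Illustrative note (not part of this diff): for Scalar == float, urolls::PacketSize is 16,
// so U3/U2/U1 are 48/32/16 columns, and with EIGEN_AVX_MAX_NUM_ROW == 8 the rounded extents
// above are simply the largest multiples that fit, e.g. N == 100 gives N_ == 96 and
// M == 100 gives M_ == 96; the leftover columns/rows are handled by the remainder cases
// later in this function. (The concrete sizes are an assumed example, not values taken
// from the source.)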
@ -216,19 +231,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,3,EIGEN_AVX_MAX_NUM_ROW,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,3,EIGEN_AVX_MAX_NUM_ROW,1,
|
||||
EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -245,19 +260,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<3, 4>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,3,4,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,3,4,1,
|
||||
EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
|
||||
urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -275,19 +290,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<3, 2>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,3,2,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,3,2,1,
|
||||
EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
|
||||
urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -306,18 +321,18 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
urolls::template setzero<3, 1>(zmm);
|
||||
{
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,3,1,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_B_LOAD_SETS*3,1>(
|
||||
urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,3,1,1,
|
||||
EIGEN_AVX_B_LOAD_SETS*3,1>(B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -339,19 +354,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,2,EIGEN_AVX_MAX_NUM_ROW,
|
||||
EIGEN_AVX_MAX_K_UNROL,EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,2,EIGEN_AVX_MAX_NUM_ROW,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -368,19 +383,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<2, 4>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,2,4,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,2,4,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
|
||||
LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -398,19 +413,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<2, 2>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,2,2,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,2,2,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
|
||||
LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -428,18 +443,18 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<2, 1>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,2,1,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,1>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
|
||||
LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,2,1,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,1>(B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -460,19 +475,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,1,
|
||||
EIGEN_AVX_B_LOAD_SETS*1,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -489,19 +504,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<1, 4>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,1,4,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,1,4,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
|
||||
LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -519,19 +534,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<1, 2>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,1,2,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,1,2,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
|
||||
LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -550,17 +565,18 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
urolls::template setzero<1, 1>(zmm);
|
||||
{
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,1,1,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,1>(
|
||||
B_t, A_t, LDB, LDA, zmm);
|
||||
urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
|
||||
LDA, zmm);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -583,19 +599,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
|
||||
B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
|
||||
B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -612,19 +628,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<1, 4>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,1,4,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
|
||||
B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,1,4,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
|
||||
urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
|
||||
B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -642,19 +658,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<1, 2>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,1,2,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
|
||||
B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
||||
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,1,2,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
|
||||
urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
|
||||
B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -672,19 +688,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
||||
urolls::template setzero<1, 1>(zmm);
|
||||
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
||||
urolls:: template microKernel<isARowMajor,1,1,EIGEN_AVX_MAX_K_UNROL,
|
||||
EIGEN_AVX_MAX_B_LOAD,1,true>(
|
||||
urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
|
||||
B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
||||
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(handleKRem) {
|
||||
for (int64_t k = K_; k < K; k++) {
|
||||
urolls:: template microKernel<isARowMajor,1,1,1,
|
||||
EIGEN_AVX_MAX_B_LOAD,1,true>(
|
||||
B_t, A_t, LDB, LDA, zmm, N - j);
|
||||
urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
|
||||
N - j);
|
||||
B_t += LDB;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
|
||||
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
||||
else A_t += LDA;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
||||
@ -707,9 +723,7 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
* The B matrix (RHS) is assumed to be row-major
|
||||
*/
|
||||
template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
|
||||
static EIGEN_ALWAYS_INLINE
|
||||
void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
|
||||
|
||||
static EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
|
||||
static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
|
||||
using urolls = unrolls::trsm<Scalar>;
|
||||
constexpr int64_t U3 = urolls::PacketSize * 3;
|
||||
@ -722,30 +736,30 @@ void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_
|
||||
int64_t k = 0;
|
||||
while (K - k >= U3) {
|
||||
urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
|
||||
urolls:: template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(
|
||||
A_arr, LDA, RHSInPacket, AInPacket);
|
||||
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
|
||||
AInPacket);
|
||||
urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
|
||||
k += U3;
|
||||
}
|
||||
if (K - k >= U2) {
|
||||
urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
|
||||
urolls:: template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(
|
||||
A_arr, LDA, RHSInPacket, AInPacket);
|
||||
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
|
||||
AInPacket);
|
||||
urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
|
||||
k += U2;
|
||||
}
|
||||
if (K - k >= U1) {
|
||||
urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
|
||||
urolls:: template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(
|
||||
A_arr, LDA, RHSInPacket, AInPacket);
|
||||
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
|
||||
AInPacket);
|
||||
urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
|
||||
k += U1;
|
||||
}
|
||||
if (K - k > 0) {
|
||||
// Handle remaining number of RHS
|
||||
urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
|
||||
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(
|
||||
A_arr, LDA, RHSInPacket, AInPacket);
|
||||
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
|
||||
AInPacket);
|
||||
urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
|
||||
}
|
||||
}
|
||||
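
// Illustrative note (not part of this diff): triSolveKernel above walks the K right-hand
// sides in vector-width chunks. For Scalar == double (PacketSize == 8) the steps are
// U3 == 24, U2 == 16, U1 == 8; e.g. K == 45 is processed as one 24-column step, one
// 16-column step, and a masked 5-column remainder. (The K value is an assumed example.)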
@ -790,9 +804,8 @@ void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64
|
||||
*
|
||||
*/
|
||||
template <typename Scalar, bool toTemp = true, bool remM = false>
|
||||
static EIGEN_ALWAYS_INLINE
|
||||
void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
|
||||
Scalar *B_temp, int64_t LDB_, int64_t remM_ = 0) {
|
||||
static EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
|
||||
int64_t remM_ = 0) {
|
||||
EIGEN_UNUSED_VARIABLE(remM_);
|
||||
using urolls = unrolls::transB<Scalar>;
|
||||
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
|
||||
@ -809,11 +822,13 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
|
||||
}
|
||||
if (K - k >= U2) {
|
||||
urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
||||
B_temp += U2; k += U2;
|
||||
B_temp += U2;
|
||||
k += U2;
|
||||
}
|
||||
if (K - k >= U1) {
|
||||
urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
||||
B_temp += U1; k += U1;
|
||||
B_temp += U1;
|
||||
k += U1;
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(U1 > 8) {
|
||||
// Note: without "if constexpr" this section of code will also be
|
||||
@ -821,7 +836,8 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
|
||||
// to make sure the counter is non-negative.
|
||||
if (K - k >= 8) {
|
||||
urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
||||
B_temp += 8; k += 8;
|
||||
B_temp += 8;
|
||||
k += 8;
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(U1 > 4) {
|
||||
@ -830,19 +846,48 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
|
||||
// to make sure the counter is non-negative.
|
||||
if (K - k >= 4) {
|
||||
urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
||||
B_temp += 4; k += 4;
|
||||
B_temp += 4;
|
||||
k += 4;
|
||||
}
|
||||
}
|
||||
if (K - k >= 2) {
|
||||
urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
||||
B_temp += 2; k += 2;
|
||||
B_temp += 2;
|
||||
k += 2;
|
||||
}
|
||||
if (K - k >= 1) {
|
||||
urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
||||
B_temp += 1; k += 1;
|
||||
B_temp += 1;
|
||||
k += 1;
|
||||
}
|
||||
}
|
||||

#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
/**
 * Reduce blocking sizes so that the size of the temporary workspace needed is less than "limit" bytes.
 * - kB must be at least psize
 * - numM must be at least EIGEN_AVX_MAX_NUM_ROW
 */
template <typename Scalar, bool isBRowMajor>
constexpr std::pair<int64_t, int64_t> trsmBlocking(const int64_t limit) {
  constexpr int64_t psize = packet_traits<Scalar>::size;
  int64_t kB = 15 * psize;
  int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
  // If B is row-major, no temp workspace is needed, so use the default blocking sizes.
  if (isBRowMajor) return {kB, numM};

  // Very simple heuristic: prefer keeping kB as large as possible to fully use the vector registers.
  for (int64_t k = kB; k > psize; k -= psize) {
    for (int64_t m = numM; m > EIGEN_AVX_MAX_NUM_ROW; m -= EIGEN_AVX_MAX_NUM_ROW) {
      if ((((k + psize - 1) / psize + 4) * psize) * m * sizeof(Scalar) < limit) {
        return {k, m};
      }
    }
  }
  return {psize, EIGEN_AVX_MAX_NUM_ROW}; // Minimum blocking size required
}
#endif // (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)

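// Worked example (illustrative only): for Scalar == double (psize == 8) the per-(k, m)
// workspace tested above is ((k + 7) / 8 + 4) * 8 * m * 8 bytes, i.e. (k + 32) * m * 8
// when k is a multiple of 8. With a hypothetical limit of 32 KiB (32768 bytes) the search
// keeps kB at its default of 120 and reduces numM: (120 + 32) * 64 * 8 = 77824 is too
// large, but (120 + 32) * 24 * 8 = 29184 fits, so trsmBlocking returns {120, 24}. The
// 32 KiB limit is only an assumed value for the example; the real call site passes
// EIGEN_STACK_ALLOCATION_LIMIT.
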
/**
 * Main triangular solve driver
 *
@@ -870,8 +915,10 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
 * Note: For RXX cases M,numRHS should be swapped.
 *
 */
template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true, bool isUnitDiag = false>
template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
          bool isUnitDiag = false>
void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
  constexpr int64_t psize = packet_traits<Scalar>::size;
  /**
   * The values for kB, numM were determined experimentally.
   * kB: Number of RHS we process at a time.
@@ -885,8 +932,30 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
   * large enough to allow GEMM updates to have larger "K"s (see below). No benchmarking has been done so far to
   * determine optimal values for numM.
   */
  const int64_t kB = (3*packet_traits<Scalar>::size)*5; // 5*U3
  const int64_t numM = 64;
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
  /**
   * If EIGEN_NO_MALLOC is requested, we try to reduce kB and numM so the maximum temp workspace required is less
   * than EIGEN_STACK_ALLOCATION_LIMIT. The actual workspace size may be less, depending on the number of vectors
   * to solve.
   * - kB must be at least psize
   * - numM must be at least EIGEN_AVX_MAX_NUM_ROW
   *
   * If B is row-major, the blocking sizes are not reduced (no temp workspace is needed).
   */
  constexpr std::pair<int64_t, int64_t> blocking_ = trsmBlocking<Scalar, isBRowMajor>(EIGEN_STACK_ALLOCATION_LIMIT);
  constexpr int64_t kB = blocking_.first;
  constexpr int64_t numM = blocking_.second;
  /**
   * If the temp workspace size exceeds EIGEN_STACK_ALLOCATION_LIMIT even with the minimum blocking sizes, we
   * raise a compile-time assertion. Use -DEIGEN_USE_AVX512_TRSM_L_KERNELS=0 if necessary.
   */
  static_assert(!(((((kB + psize - 1) / psize + 4) * psize) * numM * sizeof(Scalar) >= EIGEN_STACK_ALLOCATION_LIMIT) &&
                  !isBRowMajor),
                "Temp workspace required is too large.");
#else
  constexpr int64_t kB = (3 * psize) * 5; // 5*U3
  constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
#endif

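// Illustrative note (not part of this diff): with the default blocking the bound checked
// above is ((kB + psize - 1) / psize + 4) * psize * numM * sizeof(Scalar)
// = 19 * psize * 64 * sizeof(Scalar) = 77824 bytes for both float (psize 16, kB 240) and
// double (psize 8, kB 120). Assuming Eigen's usual default of
// EIGEN_STACK_ALLOCATION_LIMIT == 131072 (128 KiB), the defaults already fit and
// trsmBlocking leaves them unchanged; only a smaller user-defined limit triggers the
// reduction. The 128 KiB figure is an assumed default, not something stated in this diff.
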
int64_t sizeBTemp = 0;

  Scalar *B_temp = NULL;

@@ -896,9 +965,17 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
     * transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array.
     * The updated row-major copy of B is reused in the GEMM updates.
     */
    sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM;
    B_temp = (Scalar*) handmade_aligned_malloc(sizeof(Scalar)*sizeBTemp,4096);
    sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
  }

#if !defined(EIGEN_NO_MALLOC)
  EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
#elif (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
  // Use alloca if malloc is not allowed; the requested temp workspace size should be less than
  // EIGEN_STACK_ALLOCATION_LIMIT.
  ei_declare_aligned_stack_constructed_variable(Scalar, B_temp_alloca, sizeBTemp, 0);
  B_temp = B_temp_alloca;
#endif

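// Sketch of the workspace policy above (assumptions noted): when malloc is allowed the
// col-major-B path heap-allocates B_temp and frees it near the end of triSolve; under
// EIGEN_NO_MALLOC, ei_declare_aligned_stack_constructed_variable is expected to place the
// sizeBTemp Scalars on the stack (alloca-style), which is why the blocking above was first
// reduced to keep sizeBTemp under EIGEN_STACK_ALLOCATION_LIMIT. The exact expansion of that
// Eigen macro is an assumption here; only its use, not its definition, appears in this diff.
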
for (int64_t k = 0; k < numRHS; k += kB) {
|
||||
int64_t bK = numRHS - k > kB ? kB : numRHS - k;
|
||||
int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
|
||||
@ -950,11 +1027,8 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
|
||||
int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
|
||||
int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
|
||||
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
|
||||
&A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)],
|
||||
B_arr + k + indB_i*LDB,
|
||||
B_arr + k + indB_i2*LDB,
|
||||
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW,
|
||||
LDA, LDB, LDB);
|
||||
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
|
||||
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
|
||||
}
|
||||
else {
|
||||
if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
|
||||
@ -971,14 +1045,11 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
|
||||
int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
|
||||
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
|
||||
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
|
||||
&A_arr[idA<isARowMajor>(indA_i, indA_j,LDA)],
|
||||
B_temp + offB_1,
|
||||
B_arr + indB_i + (k)*LDB,
|
||||
M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff,
|
||||
LDA, LDT, LDB);
|
||||
offsetBTemp = 0; gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
|
||||
}
|
||||
else {
|
||||
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
|
||||
M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
|
||||
offsetBTemp = 0;
|
||||
gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
|
||||
} else {
|
||||
/**
|
||||
* If there is enough space in B_temp, we only update the next 8xbK values of B.
|
||||
*/
|
||||
@ -987,11 +1058,8 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
|
||||
int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
|
||||
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
|
||||
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
|
||||
&A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)],
|
||||
B_temp + offB_1,
|
||||
B_arr + indB_i + (k)*LDB,
|
||||
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff,
|
||||
LDA, LDT, LDB);
|
||||
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
|
||||
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1006,23 +1074,17 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
|
||||
int64_t indB_i = isFWDSolve ? 0 : bM;
|
||||
int64_t indB_i2 = isFWDSolve ? M_ : 0;
|
||||
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
|
||||
&A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)],
|
||||
B_arr + k +indB_i*LDB,
|
||||
B_arr + k + indB_i2*LDB,
|
||||
bM , bK, M_,
|
||||
LDA, LDB, LDB);
|
||||
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
|
||||
bK, M_, LDA, LDB, LDB);
|
||||
}
|
||||
else {
|
||||
int64_t indA_i = isFWDSolve ? M_ : 0;
|
||||
int64_t indA_j = isFWDSolve ? gemmOff : bM;
|
||||
int64_t indB_i = isFWDSolve ? M_ : 0;
|
||||
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
|
||||
gemmKernel<Scalar,isARowMajor, isBRowMajor,false,false>(
|
||||
&A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)],
|
||||
B_temp + offB_1,
|
||||
B_arr + indB_i + (k)*LDB,
|
||||
bM , bK, M_ - gemmOff,
|
||||
LDA, LDT, LDB);
|
||||
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
|
||||
B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
|
||||
M_ - gemmOff, LDA, LDT, LDB);
|
||||
}
|
||||
}
|
||||
EIGEN_IF_CONSTEXPR(!isBRowMajor) {
|
||||
@ -1030,44 +1092,45 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
|
||||
int64_t indB_i = isFWDSolve ? M_ : 0;
|
||||
int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
|
||||
copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
|
||||
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
|
||||
&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_1, bM, bkL, LDA, bkL);
|
||||
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
|
||||
B_temp + offB_1, bM, bkL, LDA, bkL);
|
||||
copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
|
||||
}
|
||||
else {
|
||||
int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
|
||||
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
|
||||
&A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind*LDB, bM, bK, LDA, LDB);
|
||||
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
|
||||
B_arr + k + ind * LDB, bM, bK, LDA, LDB);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(EIGEN_NO_MALLOC)
|
||||
EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
|
||||
#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
|
||||
#if (EIGEN_USE_AVX512_TRSM_KERNELS)
|
||||
#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
|
||||
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
|
||||
struct trsmKernelR;
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1> {
|
||||
static void kernel(Index size, Index otherSize, const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride);
|
||||
static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
|
||||
Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1> {
|
||||
static void kernel(Index size, Index otherSize, const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride);
|
||||
static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
|
||||
Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride)
|
||||
{
|
||||
Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
|
||||
Index otherStride) {
|
||||
EIGEN_UNUSED_VARIABLE(otherIncr);
|
||||
triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
|
||||
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
|
||||
@ -1075,38 +1138,35 @@ EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride)
|
||||
{
|
||||
Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
|
||||
Index otherStride) {
|
||||
EIGEN_UNUSED_VARIABLE(otherIncr);
|
||||
triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
|
||||
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
|
||||
}
|
||||
#endif // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
|
||||
|
||||
// These trsm kernels require temporary memory allocation, so disable them if malloc is not allowed.
|
||||
#if defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
|
||||
// These trsm kernels require temporary memory allocation
|
||||
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
|
||||
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
|
||||
struct trsmKernelL;
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1> {
|
||||
static void kernel(Index size, Index otherSize, const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride);
|
||||
static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
|
||||
Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1> {
|
||||
static void kernel(Index size, Index otherSize, const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride);
|
||||
static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
|
||||
Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride)
|
||||
{
|
||||
Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
|
||||
Index otherStride) {
|
||||
EIGEN_UNUSED_VARIABLE(otherIncr);
|
||||
triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
|
||||
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
|
||||
@ -1114,16 +1174,14 @@ EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride)
|
||||
{
|
||||
Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
|
||||
Index otherStride) {
|
||||
EIGEN_UNUSED_VARIABLE(otherIncr);
|
||||
triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
|
||||
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
|
||||
}
|
||||
#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
|
||||
#endif // EIGEN_USE_AVX512_TRSM_KERNELS
|
||||
}
|
||||
}
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
#endif // EIGEN_TRSM_KERNEL_IMPL_H
|
||||
|
@@ -11,8 +11,7 @@
#define EIGEN_UNROLLS_IMPL_H

template <bool isARowMajor = true>
static EIGEN_ALWAYS_INLINE
int64_t idA(int64_t i, int64_t j, int64_t LDA) {
static EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
  EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
  else return i + j * LDA;
}

@@ -60,8 +59,12 @@ namespace unrolls {
template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
  EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
  else EIGEN_IF_CONSTEXPR( N == 8) { return 0xFF >> (8 - m); }
  else EIGEN_IF_CONSTEXPR( N == 4) { return 0x0F >> (4 - m); }
  else EIGEN_IF_CONSTEXPR(N == 8) {
    return 0xFF >> (8 - m);
  }
  else EIGEN_IF_CONSTEXPR(N == 4) {
    return 0x0F >> (4 - m);
  }
  return 0;
}
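
// Illustrative note (not part of this diff): remMask<N>(m) builds the tail mask with the
// low m bits set, e.g. remMask<16>(3) == 0x0007 and remMask<8>(5) == 0x1F, which is what
// the masked loads/stores in these kernels use for partial rows/columns.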
|
||||
@ -105,10 +108,14 @@ EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8>& kernel) {
|
||||
T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
|
||||
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
|
||||
|
||||
kernel.packet[0] = T0; kernel.packet[1] = T1;
|
||||
kernel.packet[2] = T2; kernel.packet[3] = T3;
|
||||
kernel.packet[4] = T4; kernel.packet[5] = T5;
|
||||
kernel.packet[6] = T6; kernel.packet[7] = T7;
|
||||
kernel.packet[0] = T0;
|
||||
kernel.packet[1] = T1;
|
||||
kernel.packet[2] = T2;
|
||||
kernel.packet[3] = T3;
|
||||
kernel.packet[4] = T4;
|
||||
kernel.packet[5] = T5;
|
||||
kernel.packet[6] = T6;
|
||||
kernel.packet[7] = T7;
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -126,7 +133,6 @@ public:
|
||||
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
|
||||
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
|
||||
|
||||
|
||||
/***********************************
|
||||
* Auxiliary Functions for:
|
||||
* - storeC
|
||||
@ -143,9 +149,8 @@ public:
|
||||
*
|
||||
**/
|
||||
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
|
||||
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)>
|
||||
aux_storeC(Scalar *C_arr, int64_t LDC,
|
||||
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
|
||||
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
|
||||
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
|
||||
constexpr int64_t counterReverse = endN - counter;
|
||||
constexpr int64_t startN = counterReverse;
|
||||
|
||||
@ -159,8 +164,7 @@ public:
|
||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
||||
}
|
||||
else {
|
||||
pstoreu<Scalar>(
|
||||
C_arr + LDC*startN,
|
||||
pstoreu<Scalar>(C_arr + LDC * startN,
|
||||
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
|
||||
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
|
||||
}
|
||||
@ -176,26 +180,25 @@ public:
|
||||
EIGEN_IF_CONSTEXPR(remM) {
|
||||
pstoreu<Scalar>(
|
||||
C_arr + LDC * startN,
|
||||
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN,
|
||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
||||
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])),
|
||||
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
||||
preinterpret<vecHalf>(
|
||||
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
|
||||
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
||||
}
|
||||
else {
|
||||
pstoreu<Scalar>(
|
||||
C_arr + LDC * startN,
|
||||
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
|
||||
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])));
|
||||
preinterpret<vecHalf>(
|
||||
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
|
||||
}
|
||||
}
|
||||
aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
|
||||
}
|
||||
|
||||
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
|
||||
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)>
|
||||
aux_storeC(Scalar *C_arr, int64_t LDC,
|
||||
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0)
|
||||
{
|
||||
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
|
||||
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
|
||||
EIGEN_UNUSED_VARIABLE(C_arr);
|
||||
EIGEN_UNUSED_VARIABLE(LDC);
|
||||
EIGEN_UNUSED_VARIABLE(zmm);
|
||||
@ -203,9 +206,9 @@ public:
|
||||
}
|
||||
|
||||
template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
|
||||
static EIGEN_ALWAYS_INLINE
|
||||
void storeC(Scalar *C_arr, int64_t LDC,
|
||||
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0){
|
||||
static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
|
||||
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
|
||||
int64_t remM_ = 0) {
|
||||
aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
|
||||
}
@@ -235,8 +238,7 @@ public:
   * avx registers are being transposed.
   */
  template <int64_t unrollN, int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
    // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
    constexpr int64_t zmmStride = unrollN / PacketSize;
@@ -298,27 +300,25 @@ public:
   * for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remM) {
      ymm.packet[packetIndexOffset + startN] =
          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
    }
    else ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);

    aux_loadB<endN, counter - 1, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
@@ -332,16 +332,13 @@ public:
   * for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remK || remM) {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
                      remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
    }
    else {
@@ -352,10 +349,8 @@ public:
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
@@ -369,23 +364,19 @@ public:
   * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
   **/
  template <int64_t endN, int64_t counter, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;
    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false>(&B_temp[startN], LDB_, ymm);
    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
@@ -394,7 +385,6 @@ public:
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  /**
   * aux_storeBBlock
   *
@@ -402,31 +392,26 @@ public:
   * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
   **/
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(toTemp) {
      transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
    }
    else {
      transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
                                                                                          ymm, remM_);
    }
    aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
@@ -435,51 +420,43 @@ public:
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  /********************************************************
   * Wrappers for aux_XXXX to hide counter parameter
   ********************************************************/

  template <int64_t endN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
                                        PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                        int64_t remM_ = 0) {
    aux_loadB<endN, endN, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                         int64_t rem_ = 0) {
    aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }

  template <int64_t unrollN, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                             PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                             int64_t remM_ = 0) {
    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM>(&B_arr[0], LDB, ymm, remM_); }
    else {
      aux_loadBBlock<unrollN, unrollN, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }

  template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                              int64_t remM_ = 0) {
    aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
    // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
    PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
@@ -503,9 +480,9 @@ public:
  }
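  // ---------------------------------------------------------------------------
  // Illustrative sketch (not part of Eigen): scalar reference for what the
  // loadBBlock / transposeLxL / storeBBlock sequence achieves for one L x L tile.
  // Assuming B is column-major with leading dimension LDB, the packed copy in
  // B_temp holds the same tile in row-major order; the in-register shuffles
  // above compute this without a scalar gather.
  #include <cstdint>

  template <typename Scalar, int64_t L>
  void packTileRowMajorReference(const Scalar *B_arr, int64_t LDB, Scalar *B_temp) {
    for (int64_t j = 0; j < L; j++)      // column index of the tile
      for (int64_t i = 0; i < L; i++)    // row index of the tile
        B_temp[i * L + j] = B_arr[i + j * LDB];  // element (i,j): column-major in, row-major out
  }
  // ---------------------------------------------------------------------------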

  template <int64_t unrollN, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                                PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                                int64_t remM_ = 0) {
    constexpr int64_t U3 = PacketSize * 3;
    constexpr int64_t U2 = PacketSize * 2;
    constexpr int64_t U1 = PacketSize * 1;
@@ -526,11 +503,13 @@ public:
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
        transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                             ymm, remM_);
        transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                                 ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U2) {
@@ -543,20 +522,18 @@ public:
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
        transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
                                                                         &B_temp[maxUBlock], LDB_, ymm, remM_);
        transB::template transposeLxL<0>(ymm);
        transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
                                                                             &B_temp[maxUBlock], LDB_, ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U1) {
      // load LxU1 B col major, transpose LxU1 row major
      transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
      transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
@@ -601,9 +578,7 @@ public:
template <typename Scalar>
class trsm {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
@@ -622,9 +597,8 @@ public:
   * for(startK = 0; startK < endK; startK++)
   **/
  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;
@@ -642,9 +616,8 @@ public:
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
@@ -659,8 +632,8 @@ public:
   * for(startK = 0; startK < endK; startK++)
   **/
  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;
@@ -678,9 +651,8 @@ public:
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
@@ -696,8 +668,8 @@ public:
   * for(startK = 0; startK < endK; startK++)
   **/
  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endK - counter;
    constexpr int64_t startK = counterReverse;

@@ -707,8 +679,8 @@ public:
  }

  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
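  // ---------------------------------------------------------------------------
  // Illustrative sketch (not part of Eigen): divRHSByDiag never divides per
  // element. updateRHS broadcasts 1/A(currM,currM) into AInPacket.packet[currM],
  // and the recursion above multiplies the endK RHS packets of row currM by it:
  #include <cstdint>

  void divRHSByDiagReference(double *rhs, int64_t currM, int64_t endK, double diagReciprocal) {
    // rhs holds endK values per block row; pmul(RHS, AInPacket[currM]) in packet form
    for (int64_t k = 0; k < endK; k++) rhs[currM * endK + k] *= diagReciprocal;
  }
  // ---------------------------------------------------------------------------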
@@ -720,10 +692,11 @@ public:
   * for(startM = initM; startM < endM; startM++)
   * for(startK = 0; startK < endK; startK++)
   **/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = (endM - initM) * endK - counter;
    constexpr int64_t startM = initM + counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;
@@ -732,8 +705,7 @@ public:
    constexpr int64_t packetIndex = startM * endK + startK;
    EIGEN_IF_CONSTEXPR(currentM > 0) {
      RHSInPacket.packet[packetIndex] =
          pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
                 RHSInPacket.packet[packetIndex]);
    }

@@ -744,24 +716,25 @@ public:
        // This will be used in divRHSByDiag
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
        else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
      }
      else {
        // Broadcast next off diagonal element of A
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
        else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
      }
    }

    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
@@ -775,9 +748,9 @@ public:
   * for(startM = 0; startM < endM; startM++)
   **/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endM - counter;
    constexpr int64_t startM = counterReverse;

@@ -794,21 +767,21 @@ public:
    // After division, the rhs corresponding to subsequent rows of A can be partially updated
    // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
    // to be used in the next iteration.
    trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
                                                                                                AInPacket);

    // Handle division for the RHS corresponding to the final row of A.
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
    trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);

    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
@@ -824,8 +797,8 @@ public:
   * Masked loads are used for cases where endK is not a multiple of PacketSize
   */
  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
                                          PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

@@ -834,8 +807,8 @@ public:
   * Masked loads are used for cases where endK is not a multiple of PacketSize
   */
  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
                                           PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

@@ -843,8 +816,8 @@ public:
   * Only used if Triangular matrix has non-unit diagonal values
   */
  template <int64_t currM, int64_t endK>
  static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                               PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
  }

@@ -852,9 +825,11 @@ public:
   * Update right-hand sides (stored in avx registers)
   * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket.
   **/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
            int64_t currentM>
  static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }
@@ -866,11 +841,11 @@ public:
   * isUnitDiag: true => triangular matrix has unit diagonal.
   */
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
  static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    static_assert(numK >= 1 && numK <= 3, "numK out of range");
    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
  }
};
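// ---------------------------------------------------------------------------
// Illustrative sketch (not part of Eigen): scalar reference for the forward
// substitution performed by triSolveMicroKernel (isFWDSolve, non-unit diagonal).
// A is assumed row-major with leading dimension LDA; X holds numRHS columns of
// right-hand sides and is overwritten with the solution, mirroring the
// divRHSByDiag / updateRHS steps above in plain loops.
#include <cstdint>

void triSolveReference(const double *A, int64_t LDA, double *X, int64_t LDX, int64_t endM, int64_t numRHS) {
  for (int64_t m = 0; m < endM; m++) {
    const double recip = 1.0 / A[m * LDA + m];  // divRHSByDiag: multiply by the reciprocal of the diagonal
    for (int64_t k = 0; k < numRHS; k++) X[m * LDX + k] *= recip;
    for (int64_t i = m + 1; i < endM; i++) {    // updateRHS: x_i -= a_{i,m} * x_m, i.e. pnmadd
      const double a = A[i * LDA + m];
      for (int64_t k = 0; k < numRHS; k++) X[i * LDX + k] -= a * X[m * LDX + k];
    }
  }
}
// ---------------------------------------------------------------------------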

@@ -902,8 +877,8 @@ public:
   * for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;
@@ -913,9 +888,8 @@ public:
  }

  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(zmm);
  }

@@ -927,8 +901,8 @@ public:
   * for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
@@ -937,18 +911,15 @@ public:
    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
             zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
    else zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
    aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }
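  // ---------------------------------------------------------------------------
  // Illustrative sketch (not part of Eigen): semantics of the rem path above.
  // remMask<PacketSize>(rem_) selects the first rem_ lanes, so a partial column
  // of C is read, updated and written back without touching memory past its end:
  #include <cstdint>

  void updateCPartialColumnReference(double *C_col, const double *acc, int64_t rem_) {
    for (int64_t l = 0; l < rem_; l++) C_col[l] += acc[l];  // lanes >= rem_ are neither loaded nor stored
  }
  // ---------------------------------------------------------------------------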

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
@@ -963,24 +934,23 @@ public:
   * for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
                    remMask<PacketSize>(rem_));
    else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
    aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
@@ -994,8 +964,8 @@ public:
   * for(startL = 0; startL < endL; startL++)
   **/
  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endL - counter;
    constexpr int64_t startL = counterReverse;
@@ -1003,18 +973,15 @@ public:
    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
    else zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);

    aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
  }
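  // ---------------------------------------------------------------------------
  // Illustrative note (not part of Eigen): the numLoad preloaded B packets sit
  // after the C accumulators, at zmm.packet[unrollM * unrollN + startL], and
  // startL walks the unrollM x unrollN register tile of B column by column:
  #include <cstdint>

  constexpr int64_t startLoadBOffsetReference(int64_t startL, int64_t unrollM, int64_t LDB, int64_t PacketSize) {
    // column of B = startL / unrollM, packet within that column = startL % unrollM
    return (startL / unrollM) * LDB + (startL % unrollM) * PacketSize;
  }
  // ---------------------------------------------------------------------------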

  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
@@ -1028,8 +995,8 @@ public:
   * for(startB = 0; startB < endB; startB++)
   **/
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endB - counter;
    constexpr int64_t startB = counterReverse;

@@ -1039,9 +1006,8 @@ public:
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
@@ -1054,9 +1020,10 @@ public:
   * 1-D unroll
   * for(startM = 0; startM < endM; startM++)
   **/
  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    if ((numLoad / endM + currK < unrollK)) {
      constexpr int64_t counterReverse = endM - counter;
@@ -1075,12 +1042,10 @@ public:
    }
  }

  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
@@ -1095,11 +1060,11 @@ public:
   * for(startN = 0; startN < endN; startN++)
   * for(startK = 0; startK < endK; startK++)
   **/
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN * endK - counter;
    constexpr int startK = counterReverse / (endM * endN);
@@ -1107,10 +1072,8 @@ public:
    constexpr int startM = counterReverse % endM;

    EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
      gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
      gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
    }

    {
@@ -1127,9 +1090,8 @@ public:
      }
      // Bcast
      EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
        zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
            (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
      }
    }

@@ -1138,15 +1100,13 @@ public:
      gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
    aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
  }

  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDB);
@@ -1160,8 +1120,7 @@ public:
   ********************************************************/

  template <int64_t endM, int64_t endN>
  static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_setzero<endM, endN, endM * endN>(zmm);
  }

@@ -1169,15 +1128,17 @@ public:
   * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
   */
  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
                                          PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                          int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }
@@ -1186,8 +1147,9 @@ public:
   * Use numLoad registers for loading B at start of microKernel
   */
  template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
  static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
                                             PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                             int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
  }
@@ -1196,8 +1158,8 @@ public:
   * Use numBCast registers for broadcasting A at start of microKernel
   */
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
  }

@@ -1205,9 +1167,9 @@ public:
   * Loads next set of B into vector registers between each K unroll.
   */
  template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
                                        PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                        int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
  }
@@ -1235,18 +1197,16 @@ public:
   * From testing, there are no register spills with clang. There are register spills with GNU, which
   * causes a performance hit.
   */
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
            bool rem = false>
  static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                              int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
                                                                                               rem_);
  }

};
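// ---------------------------------------------------------------------------
// Illustrative sketch (not part of Eigen): ignoring register allocation
// (numLoad B packets in flight, numBCast A broadcasts), one microKernel call
// accumulates a small tile as C += A * B. Scalar reference, assuming
// column-major tiles with leading dimensions LDA/LDB/LDC:
#include <cstdint>

void gemmTileReference(const double *A, const double *B, double *C, int64_t M, int64_t N, int64_t K,
                       int64_t LDA, int64_t LDB, int64_t LDC) {
  for (int64_t n = 0; n < N; n++)
    for (int64_t k = 0; k < K; k++) {
      const double b = B[k + n * LDB];  // reused across the whole column, like the broadcast operand
      for (int64_t m = 0; m < M; m++) C[m + n * LDC] += A[m + k * LDA] * b;  // one fused multiply-add per lane
    }
}
// ---------------------------------------------------------------------------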
} // namespace unrolls

#endif // EIGEN_UNROLLS_IMPL_H
@@ -171,7 +171,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
    std::ptrdiff_t l1, l2, l3;
    manage_caching_sizes(GetAction, &l1, &l2, &l3);

#if (EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
    EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
                         (std::is_same<Scalar,float>::value ||
                          std::is_same<Scalar,double>::value)) ) {
@@ -246,7 +246,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
    // tr solve
    {
      Index i = IsLower ? k2+k1 : k2-k1;
#if EIGEN_USE_AVX512_TRSM_L_KERNELS
      EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
                           (std::is_same<Scalar,float>::value ||
                            std::is_same<Scalar,double>::value)) ) {
@@ -318,7 +318,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
{
  Index rows = otherSize;

#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
  EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
                       (std::is_same<Scalar,float>::value ||
                        std::is_same<Scalar,double>::value)) ) {