AVX512 TRSM kernels use alloca if EIGEN_NO_MALLOC requested

Authored by b-shi on 2022-06-17 18:05:26 +00:00; committed by Rasmus Munk Larsen
parent 4d1c16eab8
commit 37673ca1bc
4 changed files with 2087 additions and 2030 deletions

@@ -12,13 +12,20 @@
 #include "../../InternalHeaderCheck.h"
 
-#define EIGEN_USE_AVX512_TRSM_KERNELS  // Comment out to prevent using optimized trsm kernels.
-
-#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
-#define EIGEN_USE_AVX512_TRSM_R_KERNELS
-#if !defined(EIGEN_NO_MALLOC)  // Separate MACRO since these kernels require malloc
-#define EIGEN_USE_AVX512_TRSM_L_KERNELS
-#endif
+#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
+#define EIGEN_USE_AVX512_TRSM_KERNELS 1
+#endif
+
+#if EIGEN_USE_AVX512_TRSM_KERNELS
+#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
+#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
+#endif
+#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
+#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
+#endif
+#else  // EIGEN_USE_AVX512_TRSM_KERNELS == 0
+#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
+#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
 #endif
 
 #if defined(EIGEN_HAS_CXX17_IFCONSTEXPR)
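Since the switches above are now always defined to 0 or 1, a build can override them explicitly instead of editing this header. A minimal sketch (the defines simply have to precede any Eigen include; values other than 0 or 1 are not meaningful):

```cpp
// Opt out of the optimized AVX-512 TRSM kernels entirely...
#define EIGEN_USE_AVX512_TRSM_KERNELS 0
// ...or keep them but disable only the left-solve variants:
// #define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
#include <Eigen/Dense>
```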
@@ -49,8 +56,7 @@ typedef Packet4d vecHalfDouble;
 // Note: this depends on macros and typedefs above.
 #include "TrsmUnrolls.inc"
 
-#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
+#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
 /**
  * For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly
  * is faster than the packed versions in TriangularSolverMatrix.h.
@@ -67,34 +73,46 @@ typedef Packet4d vecHalfDouble;
  * M = Dimension of triangular matrix
  *
  */
-#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS  // Comment out to disable no-copy dispatch
-
-#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
-#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
-#if !defined(EIGEN_NO_MALLOC)
-#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
-#endif
-#endif
+#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
+#endif
+
+#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
+
+#if EIGEN_USE_AVX512_TRSM_R_KERNELS
+#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
+#endif  // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
+#endif
+
+#if EIGEN_USE_AVX512_TRSM_L_KERNELS
+#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
+#endif
+#endif  // EIGEN_USE_AVX512_TRSM_L_KERNELS
+
+#else  // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
+#endif  // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
+
 template <typename Scalar>
 int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
   const int64_t U3 = 3 * packet_traits<Scalar>::size;
   const int64_t MaxNb = 5 * U3;
   int64_t Nb = std::min(MaxNb, N);
-  double cutoff_d = (((L2Size*L2Cap)/(sizeof(Scalar)))-(EIGEN_AVX_MAX_NUM_ROW)*Nb)/
-                    ((EIGEN_AVX_MAX_NUM_ROW)+Nb);
+  double cutoff_d =
+      (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
   int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
   return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
 }
 #endif
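As a worked illustration of the heuristic above (the cache numbers are made up, not from the patch): for float the packet size is 16, EIGEN_AVX_MAX_NUM_ROW is 8, and MaxNb is 240, so a 1 MiB L2 with half of it budgeted for the solve yields a cutoff of 520 rows.

```cpp
// Standalone re-computation of avx512_trsm_cutoff<float> with the constants spelled out.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int64_t cutoff_float(int64_t L2Size, int64_t N, double L2Cap) {
  const int64_t MaxNb = 5 * (3 * 16);               // 5 * U3, with a 16-lane float packet
  int64_t Nb = std::min(MaxNb, N);
  double cutoff_d = ((L2Size * L2Cap) / sizeof(float) - 8.0 * Nb) / (8.0 + Nb);
  return (static_cast<int64_t>(cutoff_d) / 8) * 8;  // round down to a multiple of 8 rows
}

int main() { std::printf("%lld\n", (long long)cutoff_float(1 << 20, 240, 0.5)); }  // prints 520
```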
 /**
  * Used by gemmKernel for the case A/B row-major and C col-major.
  */
 template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
-static EIGEN_ALWAYS_INLINE
-void transStoreC(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
-                 Scalar *C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
+static EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
+                                            Scalar *C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
   EIGEN_UNUSED_VARIABLE(remN_);
   EIGEN_UNUSED_VARIABLE(remM_);
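For orientation only (this is not the kernel's implementation): transStoreC takes the accumulator block, held one row per register, and writes it into the column-major C array, so the store is effectively a transpose, with remM_/remN_ bounding partial blocks.

```cpp
// Scalar reference of the data movement; the real routine uses AVX-512 shuffles and masked
// stores, and may combine the store with an accumulation. "acc" stands in for the zmm block.
#include <cstdint>
template <typename Scalar>
void transStoreC_reference(const Scalar *acc, int64_t accLD, Scalar *C_arr, int64_t LDC, int64_t rows, int64_t cols) {
  for (int64_t j = 0; j < cols; ++j)
    for (int64_t i = 0; i < rows; ++i) C_arr[i + j * LDC] = acc[i * accLD + j];  // row block -> column-major C
}
```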
@@ -194,16 +212,13 @@ void transStoreC(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
  * handleKRem: Handle arbitrary K? This is not needed for trsm.
  */
 template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
-void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
-                int64_t M, int64_t N, int64_t K,
-                int64_t LDA, int64_t LDB, int64_t LDC) {
+void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
+                int64_t LDC) {
   using urolls = unrolls::gemm<Scalar, isAdd>;
   constexpr int64_t U3 = urolls::PacketSize * 3;
   constexpr int64_t U2 = urolls::PacketSize * 2;
   constexpr int64_t U1 = urolls::PacketSize * 1;
-  using vec = typename std::conditional<std::is_same<Scalar, float>::value,
-                                        vecFullFloat,
-                                        vecFullDouble>::type;
+  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
   int64_t N_ = (N / U3) * U3;
   int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
   int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
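The hunks that follow are easier to scan with the blocking scheme in mind: columns are handled in strips of three, two, then one packet (U3/U2/U1) plus a masked remainder, rows in groups of up to EIGEN_AVX_MAX_NUM_ROW, and the K loop is unrolled by EIGEN_AVX_MAX_K_UNROL with a scalar tail. A much-simplified scalar outline of that structure (illustrative only; whether the block is added to or subtracted from C is what the isAdd flag controls):

```cpp
#include <algorithm>
#include <cstdint>
// Illustrative loop nest only; the real kernel keeps the row-block accumulators in zmm
// registers and delegates the K loop to urolls::microKernel.
template <typename Scalar>
void gemmKernel_outline(const Scalar *A, const Scalar *B, Scalar *C, int64_t M, int64_t N, int64_t K, int64_t LDA,
                        int64_t LDB, int64_t LDC) {
  const int64_t rowsPerBlock = 8;  // stands in for EIGEN_AVX_MAX_NUM_ROW
  for (int64_t j = 0; j < N; ++j) {                    // really advanced in packet strips
    for (int64_t i = 0; i < M; i += rowsPerBlock) {
      int64_t rows = std::min<int64_t>(rowsPerBlock, M - i);
      for (int64_t r = 0; r < rows; ++r) {
        Scalar acc = Scalar(0);
        for (int64_t k = 0; k < K; ++k) acc += A[(i + r) * LDA + k] * B[k * LDB + j];  // row-major A and B
        C[(i + r) + j * LDC] += acc;                   // added or subtracted depending on isAdd
      }
    }
  }
}
```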
@@ -216,19 +231,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
       PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
       urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
       for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
-        urolls:: template microKernel<isARowMajor,3,EIGEN_AVX_MAX_NUM_ROW,EIGEN_AVX_MAX_K_UNROL,
-                                      EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>(
-          B_t, A_t, LDB, LDA, zmm);
+        urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
+                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
         B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
-        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA;
+        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
+        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
       }
       EIGEN_IF_CONSTEXPR(handleKRem) {
         for (int64_t k = K_; k < K; k++) {
-          urolls:: template microKernel<isARowMajor,3,EIGEN_AVX_MAX_NUM_ROW,1,
-                                        EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
-            B_t, A_t, LDB, LDA, zmm);
+          urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
+                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
           B_t += LDB;
-          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA;
+          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
+          else A_t += LDA;
         }
       }
       EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -245,19 +260,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<3, 4>(zmm); urolls::template setzero<3, 4>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,3,4,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,3,4,1, urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
B_t, A_t, LDB, LDA, zmm); B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -275,19 +290,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<3, 2>(zmm); urolls::template setzero<3, 2>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,3,2,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,3,2,1, urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
EIGEN_AVX_B_LOAD_SETS*3,EIGEN_AVX_MAX_A_BCAST>(
B_t, A_t, LDB, LDA, zmm); B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -306,18 +321,18 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
urolls::template setzero<3, 1>(zmm); urolls::template setzero<3, 1>(zmm);
{ {
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,3,1,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
EIGEN_AVX_B_LOAD_SETS*3,1>(
B_t, A_t, LDB, LDA, zmm); B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,3,1,1, urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
EIGEN_AVX_B_LOAD_SETS*3,1>(B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -339,19 +354,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm); urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,2,EIGEN_AVX_MAX_NUM_ROW, urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_K_UNROL,EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,2,EIGEN_AVX_MAX_NUM_ROW,1, urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -368,19 +383,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<2, 4>(zmm); urolls::template setzero<2, 4>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,2,4,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,2,4,1, urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -398,19 +413,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<2, 2>(zmm); urolls::template setzero<2, 2>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,2,2,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,2,2,1, urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -428,18 +443,18 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<2, 1>(zmm); urolls::template setzero<2, 1>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,2,1,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
EIGEN_AVX_MAX_B_LOAD,1>( LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,2,1,1, urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
EIGEN_AVX_MAX_B_LOAD,1>(B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -460,19 +475,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm); urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,1, urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
EIGEN_AVX_B_LOAD_SETS*1,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -489,19 +504,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 4>(zmm); urolls::template setzero<1, 4>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,1,4,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,1,4,1, urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -519,19 +534,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 2>(zmm); urolls::template setzero<1, 2>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,1,2,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,1,2,1, urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST>( LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -550,17 +565,18 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
urolls::template setzero<1, 1>(zmm); urolls::template setzero<1, 1>(zmm);
{ {
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,1,1,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
EIGEN_AVX_MAX_B_LOAD,1>( LDA, zmm);
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm); urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -583,19 +599,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm); urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>( EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
B_t, A_t, LDB, LDA, zmm, N - j);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,1,EIGEN_AVX_MAX_NUM_ROW,1, urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>( EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
B_t, A_t, LDB, LDA, zmm, N - j);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -612,19 +628,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 4>(zmm); urolls::template setzero<1, 4>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,1,4,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>( EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
B_t, A_t, LDB, LDA, zmm, N - j);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,1,4,1, urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
B_t, A_t, LDB, LDA, zmm, N - j); B_t, A_t, LDB, LDA, zmm, N - j);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -642,19 +658,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 2>(zmm); urolls::template setzero<1, 2>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,1,2,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>( EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
B_t, A_t, LDB, LDA, zmm, N - j);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,1,2,1, urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
EIGEN_AVX_MAX_B_LOAD,EIGEN_AVX_MAX_A_BCAST,true>(
B_t, A_t, LDB, LDA, zmm, N - j); B_t, A_t, LDB, LDA, zmm, N - j);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@ -672,19 +688,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm; PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 1>(zmm); urolls::template setzero<1, 1>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls:: template microKernel<isARowMajor,1,1,EIGEN_AVX_MAX_K_UNROL, urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
EIGEN_AVX_MAX_B_LOAD,1,true>(
B_t, A_t, LDB, LDA, zmm, N - j); B_t, A_t, LDB, LDA, zmm, N - j);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB; B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
} }
EIGEN_IF_CONSTEXPR(handleKRem) { EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) { for (int64_t k = K_; k < K; k++) {
urolls:: template microKernel<isARowMajor,1,1,1, urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
EIGEN_AVX_MAX_B_LOAD,1,true>( N - j);
B_t, A_t, LDB, LDA, zmm, N - j);
B_t += LDB; B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
} }
} }
EIGEN_IF_CONSTEXPR(isCRowMajor) { EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -707,9 +723,7 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
  * The B matrix (RHS) is assumed to be row-major
  */
 template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
-static EIGEN_ALWAYS_INLINE
-void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
+static EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
   static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
   using urolls = unrolls::trsm<Scalar>;
   constexpr int64_t U3 = urolls::PacketSize * 3;
@@ -722,30 +736,30 @@ void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_
   int64_t k = 0;
   while (K - k >= U3) {
     urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
-    urolls:: template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(
-      A_arr, LDA, RHSInPacket, AInPacket);
+    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
+                                                                                          AInPacket);
     urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
     k += U3;
   }
   if (K - k >= U2) {
     urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
-    urolls:: template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(
-      A_arr, LDA, RHSInPacket, AInPacket);
+    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
+                                                                                          AInPacket);
     urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
     k += U2;
   }
   if (K - k >= U1) {
     urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
-    urolls:: template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(
-      A_arr, LDA, RHSInPacket, AInPacket);
+    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
+                                                                                          AInPacket);
     urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
     k += U1;
   }
   if (K - k > 0) {
     // Handle remaining number of RHS
     urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
-    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(
-      A_arr, LDA, RHSInPacket, AInPacket);
+    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
+                                                                                          AInPacket);
     urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
   }
 }
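The kernel above is block forward/back substitution: up to unrollM rows of the triangular factor stay in registers while U3, U2, then U1 packets of right-hand sides are loaded, solved, and stored, with a masked tail for the leftover columns. A plain scalar version of the forward-solve case, for reference only:

```cpp
// Reference forward substitution: solves L * X = B in place, one column of B at a time.
// Indexing assumes a row-major L here purely for readability.
#include <cstdint>
template <typename Scalar, bool isUnitDiag>
void triSolve_reference(const Scalar *L, int64_t LDA, Scalar *B, int64_t LDB, int64_t M, int64_t K) {
  for (int64_t j = 0; j < K; ++j) {
    for (int64_t i = 0; i < M; ++i) {
      Scalar s = B[i * LDB + j];
      for (int64_t p = 0; p < i; ++p) s -= L[i * LDA + p] * B[p * LDB + j];
      B[i * LDB + j] = isUnitDiag ? s : s / L[i * LDA + i];
    }
  }
}
```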
@@ -790,9 +804,8 @@ void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64
  *
  */
 template <typename Scalar, bool toTemp = true, bool remM = false>
-static EIGEN_ALWAYS_INLINE
-void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
-                     Scalar *B_temp, int64_t LDB_, int64_t remM_ = 0) {
+static EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
+                                                int64_t remM_ = 0) {
   EIGEN_UNUSED_VARIABLE(remM_);
   using urolls = unrolls::transB<Scalar>;
   using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
@@ -809,11 +822,13 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
   }
   if (K - k >= U2) {
     urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
-    B_temp += U2; k += U2;
+    B_temp += U2;
+    k += U2;
   }
   if (K - k >= U1) {
     urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
-    B_temp += U1; k += U1;
+    B_temp += U1;
+    k += U1;
   }
   EIGEN_IF_CONSTEXPR(U1 > 8) {
     // Note: without "if constexpr" this section of code will also be
@@ -821,7 +836,8 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
     // to make sure the counter is not non-negative.
     if (K - k >= 8) {
       urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
-      B_temp += 8; k += 8;
+      B_temp += 8;
+      k += 8;
     }
   }
   EIGEN_IF_CONSTEXPR(U1 > 4) {
@@ -830,19 +846,48 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
     // Note: without "if constexpr" this section of code will also be
     // to make sure the counter is not non-negative.
     if (K - k >= 4) {
      urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
-      B_temp += 4; k += 4;
+      B_temp += 4;
+      k += 4;
     }
   }
   if (K - k >= 2) {
     urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
-    B_temp += 2; k += 2;
+    B_temp += 2;
+    k += 2;
   }
   if (K - k >= 1) {
     urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
-    B_temp += 1; k += 1;
+    B_temp += 1;
+    k += 1;
   }
 }
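copyBToRowMajor is the staging step used below: with toTemp == true it transposes a block of the column-major RHS into the row-major scratch buffer B_temp, and with toTemp == false it copies the solved values back, working through the columns in chunks of U3, U2, U1, then 8, 4, 2 and 1. A scalar picture of the two directions, for reference only:

```cpp
// Reference-only version of the staging copy; the real code moves blocks of rows with
// vectorized transposes. "rows" plays the role of the (possibly partial) row count.
#include <cstdint>
template <typename Scalar, bool toTemp>
void copyB_reference(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_, int64_t rows) {
  for (int64_t k = 0; k < K; ++k)
    for (int64_t r = 0; r < rows; ++r) {
      if (toTemp)
        B_temp[r * LDB_ + k] = B_arr[r + k * LDB];  // column-major B -> row-major scratch
      else
        B_arr[r + k * LDB] = B_temp[r * LDB_ + k];  // solved scratch -> back into B
    }
}
```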
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
/**
* Reduce blocking sizes so that the size of the temporary workspace needed is less than "limit" bytes,
* - kB must be at least psize
* - numM must be at least EIGEN_AVX_MAX_NUM_ROW
*/
template <typename Scalar, bool isBRowMajor>
constexpr std::pair<int64_t, int64_t> trsmBlocking(const int64_t limit) {
constexpr int64_t psize = packet_traits<Scalar>::size;
int64_t kB = 15 * psize;
int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
// If B is rowmajor, no temp workspace needed, so use default blocking sizes.
if (isBRowMajor) return {kB, numM};
// Very simple heuristic, prefer keeping kB as large as possible to fully use vector registers.
for (int64_t k = kB; k > psize; k -= psize) {
for (int64_t m = numM; m > EIGEN_AVX_MAX_NUM_ROW; m -= EIGEN_AVX_MAX_NUM_ROW) {
if ((((k + psize - 1) / psize + 4) * psize) * m * sizeof(Scalar) < limit) {
return {k, m};
}
}
}
return {psize, EIGEN_AVX_MAX_NUM_ROW}; // Minimum blocking size required
}
#endif // (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
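To put numbers on the heuristic (my arithmetic, not from the patch): the workspace estimate is ((kB rounded up to a multiple of psize) + 4*psize) * numM elements. With the default blocking kB = 15*psize and numM = 8*EIGEN_AVX_MAX_NUM_ROW = 64, that comes to 77,824 bytes for float and for double alike, which already fits under Eigen's default EIGEN_STACK_ALLOCATION_LIMIT of 131072 bytes; the shrinking loop above therefore only has work to do when that limit is set lower than the default.

```cpp
// Standalone check of those numbers (assumes EIGEN_AVX_MAX_NUM_ROW == 8 and the default
// 131072-byte stack allocation limit; both are assumptions, not taken from this diff).
#include <cstdint>
#include <cstdio>
constexpr int64_t workspaceBytes(int64_t kB, int64_t numM, int64_t psize, int64_t scalarSize) {
  return (((kB + psize - 1) / psize + 4) * psize) * numM * scalarSize;
}
int main() {
  std::printf("float : %lld bytes\n", (long long)workspaceBytes(15 * 16, 64, 16, 4));  // 77824
  std::printf("double: %lld bytes\n", (long long)workspaceBytes(15 * 8, 64, 8, 8));    // 77824
}
```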
 /**
  * Main triangular solve driver
  *
@@ -870,8 +915,10 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
  * Note: For RXX cases M,numRHS should be swapped.
  *
  */
-template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true, bool isUnitDiag = false>
+template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
+          bool isUnitDiag = false>
 void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
+  constexpr int64_t psize = packet_traits<Scalar>::size;
   /**
    * The values for kB, numM were determined experimentally.
    * kB: Number of RHS we process at a time.
@@ -885,8 +932,30 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
    * large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to
    * determine optimal values for numM.
    */
-  const int64_t kB = (3*packet_traits<Scalar>::size)*5;  // 5*U3
-  const int64_t numM = 64;
+#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
+  /**
+   * If EIGEN_NO_MALLOC is requested, we try to reduce kB and numM so the maximum temp workspace required is less
+   * than EIGEN_STACK_ALLOCATION_LIMIT. Actual workspace size may be less, depending on the number of vectors to
+   * solve.
+   *  - kB must be at least psize
+   *  - numM must be at least EIGEN_AVX_MAX_NUM_ROW
+   *
+   * If B is row-major, the blocking sizes are not reduced (no temp workspace needed).
+   */
+  constexpr std::pair<int64_t, int64_t> blocking_ = trsmBlocking<Scalar, isBRowMajor>(EIGEN_STACK_ALLOCATION_LIMIT);
+  constexpr int64_t kB = blocking_.first;
+  constexpr int64_t numM = blocking_.second;
+  /**
+   * If the temp workspace size exceeds EIGEN_STACK_ALLOCATION_LIMIT even with the minimum blocking sizes,
+   * we throw an assertion. Use -DEIGEN_USE_AVX512_TRSM_L_KERNELS=0 if necessary
+   */
+  static_assert(!(((((kB + psize - 1) / psize + 4) * psize) * numM * sizeof(Scalar) >= EIGEN_STACK_ALLOCATION_LIMIT) &&
+                  !isBRowMajor),
+                "Temp workspace required is too large.");
+#else
+  constexpr int64_t kB = (3 * psize) * 5;  // 5*U3
+  constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
+#endif
 
   int64_t sizeBTemp = 0;
   Scalar *B_temp = NULL;
@@ -896,9 +965,17 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
    * transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array.
    * The updated row-major copy of B is reused in the GEMM updates.
    */
-    sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM;
-    B_temp = (Scalar*) handmade_aligned_malloc(sizeof(Scalar)*sizeBTemp,4096);
+    sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
   }
+
+#if !defined(EIGEN_NO_MALLOC)
+  EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
+#elif (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
+  // Use alloca if malloc not allowed, requested temp workspace size should be less than EIGEN_STACK_ALLOCATION_LIMIT
+  ei_declare_aligned_stack_constructed_variable(Scalar, B_temp_alloca, sizeBTemp, 0);
+  B_temp = B_temp_alloca;
+#endif
+
   for (int64_t k = 0; k < numRHS; k += kB) {
     int64_t bK = numRHS - k > kB ? kB : numRHS - k;
     int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
@ -950,11 +1027,8 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW); int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW); int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>( gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
&A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)], &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
B_arr + k + indB_i*LDB, EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
B_arr + k + indB_i2*LDB,
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW,
LDA, LDB, LDB);
} }
else { else {
if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) { if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
@ -971,14 +1045,11 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0; int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>( gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
&A_arr[idA<isARowMajor>(indA_i, indA_j,LDA)], &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
B_temp + offB_1, M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
B_arr + indB_i + (k)*LDB, offsetBTemp = 0;
M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
LDA, LDT, LDB); } else {
offsetBTemp = 0; gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
}
else {
/** /**
* If there is enough space in B_temp, we only update the next 8xbK values of B. * If there is enough space in B_temp, we only update the next 8xbK values of B.
*/ */
@ -987,11 +1058,8 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW); int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>( gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
&A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)], &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
B_temp + offB_1, EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
B_arr + indB_i + (k)*LDB,
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff,
LDA, LDT, LDB);
} }
} }
} }
@ -1006,23 +1074,17 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
int64_t indB_i = isFWDSolve ? 0 : bM; int64_t indB_i = isFWDSolve ? 0 : bM;
int64_t indB_i2 = isFWDSolve ? M_ : 0; int64_t indB_i2 = isFWDSolve ? M_ : 0;
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>( gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
&A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)], &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
B_arr + k +indB_i*LDB, bK, M_, LDA, LDB, LDB);
B_arr + k + indB_i2*LDB,
bM , bK, M_,
LDA, LDB, LDB);
} }
else { else {
int64_t indA_i = isFWDSolve ? M_ : 0; int64_t indA_i = isFWDSolve ? M_ : 0;
int64_t indA_j = isFWDSolve ? gemmOff : bM; int64_t indA_j = isFWDSolve ? gemmOff : bM;
int64_t indB_i = isFWDSolve ? M_ : 0; int64_t indB_i = isFWDSolve ? M_ : 0;
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
gemmKernel<Scalar,isARowMajor, isBRowMajor,false,false>( gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
&A_arr[idA<isARowMajor>(indA_i,indA_j,LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
B_temp + offB_1, M_ - gemmOff, LDA, LDT, LDB);
B_arr + indB_i + (k)*LDB,
bM , bK, M_ - gemmOff,
LDA, LDT, LDB);
} }
} }
EIGEN_IF_CONSTEXPR(!isBRowMajor) { EIGEN_IF_CONSTEXPR(!isBRowMajor) {
@ -1030,44 +1092,45 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
int64_t indB_i = isFWDSolve ? M_ : 0; int64_t indB_i = isFWDSolve ? M_ : 0;
int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL; int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM); copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>( triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_1, bM, bkL, LDA, bkL); B_temp + offB_1, bM, bkL, LDA, bkL);
copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM); copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
} }
else { else {
int64_t ind = isFWDSolve ? M_ : M - 1 - M_; int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>( triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
&A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind*LDB, bM, bK, LDA, LDB); B_arr + k + ind * LDB, bM, bK, LDA, LDB);
} }
} }
} }
+#if !defined(EIGEN_NO_MALLOC)
   EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
+#endif
 }
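At the user level, the point of the change is code built with EIGEN_NO_MALLOC: the left-side solves previously needed a heap-allocated staging buffer, and now place it on the stack instead. A hedged usage sketch with fixed-size matrices (the sizes, and the assumption that the build enables AVX-512, are illustrative):

```cpp
#define EIGEN_NO_MALLOC  // must come before any Eigen header
#include <Eigen/Dense>

int main() {
  Eigen::Matrix<double, 32, 32> A = Eigen::Matrix<double, 32, 32>::Random();
  Eigen::Matrix<double, 32, 8> B = Eigen::Matrix<double, 32, 8>::Random();
  A.diagonal().array() += 32.0;  // keep the triangular factor well conditioned

  // Left solve: A_lower * X = B, in place, with no heap allocation.
  A.triangularView<Eigen::Lower>().solveInPlace(B);

  // Right solve: X * A_upper = C, the case covered by the trsmKernelR specializations below.
  Eigen::Matrix<double, 8, 32> C = Eigen::Matrix<double, 8, 32>::Random();
  A.triangularView<Eigen::Upper>().solveInPlace<Eigen::OnTheRight>(C);
  return 0;
}
```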
 // Template specializations of trsmKernelL/R for float/double and inner strides of 1.
-#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
+#if (EIGEN_USE_AVX512_TRSM_KERNELS)
+#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride> template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct trsmKernelR; struct trsmKernelR;
template <typename Index, int Mode, int TriStorageOrder> template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1> { struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1> {
static void kernel(Index size, Index otherSize, const float* _tri, Index triStride, static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
float* _other, Index otherIncr, Index otherStride); Index otherStride);
}; };
template <typename Index, int Mode, int TriStorageOrder> template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1> { struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1> {
static void kernel(Index size, Index otherSize, const double* _tri, Index triStride, static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
double* _other, Index otherIncr, Index otherStride); Index otherStride);
}; };
template <typename Index, int Mode, int TriStorageOrder> template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1>::kernel( EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
Index size, Index otherSize, Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
const float* _tri, Index triStride, Index otherStride) {
float* _other, Index otherIncr, Index otherStride)
{
EIGEN_UNUSED_VARIABLE(otherIncr); EIGEN_UNUSED_VARIABLE(otherIncr);
triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>( triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride); const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
@ -1075,38 +1138,35 @@ EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1
template <typename Index, int Mode, int TriStorageOrder> template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1>::kernel( EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
Index size, Index otherSize, Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
const double* _tri, Index triStride, Index otherStride) {
double* _other, Index otherIncr, Index otherStride)
{
EIGEN_UNUSED_VARIABLE(otherIncr); EIGEN_UNUSED_VARIABLE(otherIncr);
triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>( triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride); const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
} }
+#endif  // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
 
-// These trsm kernels require temporary memory allocation, so disable them if malloc is not allowed.
-#if defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
+// These trsm kernels require temporary memory allocation
+#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride> template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct trsmKernelL; struct trsmKernelL;
template <typename Index, int Mode, int TriStorageOrder> template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1> { struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1> {
static void kernel(Index size, Index otherSize, const float* _tri, Index triStride, static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
float* _other, Index otherIncr, Index otherStride); Index otherStride);
}; };
template <typename Index, int Mode, int TriStorageOrder> template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1> { struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1> {
static void kernel(Index size, Index otherSize, const double* _tri, Index triStride, static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
double* _other, Index otherIncr, Index otherStride); Index otherStride);
}; };
template <typename Index, int Mode, int TriStorageOrder> template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1>::kernel( EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
Index size, Index otherSize, Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
const float* _tri, Index triStride, Index otherStride) {
float* _other, Index otherIncr, Index otherStride)
{
EIGEN_UNUSED_VARIABLE(otherIncr); EIGEN_UNUSED_VARIABLE(otherIncr);
triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>( triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride); const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
@ -1114,16 +1174,14 @@ EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1
template <typename Index, int Mode, int TriStorageOrder> template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1>::kernel( EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
Index size, Index otherSize, Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
const double* _tri, Index triStride, Index otherStride) {
double* _other, Index otherIncr, Index otherStride)
{
EIGEN_UNUSED_VARIABLE(otherIncr); EIGEN_UNUSED_VARIABLE(otherIncr);
triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>( triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride); const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
} }
#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS #endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
#endif // EIGEN_USE_AVX512_TRSM_KERNELS #endif // EIGEN_USE_AVX512_TRSM_KERNELS
} } // namespace internal
} } // namespace Eigen
#endif // EIGEN_TRSM_KERNEL_IMPL_H #endif // EIGEN_TRSM_KERNEL_IMPL_H
@ -11,8 +11,7 @@
#define EIGEN_UNROLLS_IMPL_H #define EIGEN_UNROLLS_IMPL_H
template <bool isARowMajor = true> template <bool isARowMajor = true>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
int64_t idA(int64_t i, int64_t j, int64_t LDA) {
EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j; EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
else return i + j * LDA; else return i + j * LDA;
} }
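// Illustrative note (not in the original source): idA maps a (row, column) pair to a linear
// offset given the leading dimension LDA. For example, idA<true>(2, 5, 100) yields
// 2 * 100 + 5 = 205 for a row-major A, while idA<false>(2, 5, 100) yields 2 + 5 * 100 = 502
// for a column-major A.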
@ -60,8 +59,12 @@ namespace unrolls {
template <int64_t N> template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); } EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
else EIGEN_IF_CONSTEXPR( N == 8) { return 0xFF >> (8 - m); } else EIGEN_IF_CONSTEXPR(N == 8) {
else EIGEN_IF_CONSTEXPR( N == 4) { return 0x0F >> (4 - m); } return 0xFF >> (8 - m);
}
else EIGEN_IF_CONSTEXPR(N == 4) {
return 0x0F >> (4 - m);
}
return 0; return 0;
} }
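// Illustrative note (not in the original source): remMask<N>(m) builds an integer mask with the
// low m bits set, used as an AVX512 lane mask for remainder handling. For example,
// remMask<16>(3) == 0xFFFF >> 13 == 0x0007 (only the first 3 of 16 lanes active), and
// remMask<8>(5) == 0xFF >> 3 == 0x1F.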
@ -105,10 +108,14 @@ EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8>& kernel) {
T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
kernel.packet[0] = T0; kernel.packet[1] = T1; kernel.packet[0] = T0;
kernel.packet[2] = T2; kernel.packet[3] = T3; kernel.packet[1] = T1;
kernel.packet[4] = T4; kernel.packet[5] = T5; kernel.packet[2] = T2;
kernel.packet[6] = T6; kernel.packet[7] = T7; kernel.packet[3] = T3;
kernel.packet[4] = T4;
kernel.packet[5] = T5;
kernel.packet[6] = T6;
kernel.packet[7] = T7;
} }
template <> template <>
@ -126,7 +133,6 @@ public:
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type; using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size; static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/*********************************** /***********************************
* Auxiliary Functions for: * Auxiliary Functions for:
* - storeC * - storeC
@ -143,9 +149,8 @@ public:
* *
**/ **/
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM> template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
aux_storeC(Scalar *C_arr, int64_t LDC, Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN - counter; constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse; constexpr int64_t startN = counterReverse;
@ -159,8 +164,7 @@ public:
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)); remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
} }
else { else {
pstoreu<Scalar>( pstoreu<Scalar>(C_arr + LDC * startN,
C_arr + LDC*startN,
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN), padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]))); preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
} }
@ -176,26 +180,25 @@ public:
EIGEN_IF_CONSTEXPR(remM) { EIGEN_IF_CONSTEXPR(remM) {
pstoreu<Scalar>( pstoreu<Scalar>(
C_arr + LDC * startN, C_arr + LDC * startN,
padd(ploadu<vecHalf>((const Scalar*)C_arr + LDC*startN, padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)), preinterpret<vecHalf>(
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])), zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)); remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
} }
else { else {
pstoreu<Scalar>( pstoreu<Scalar>(
C_arr + LDC * startN, C_arr + LDC * startN,
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN), padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)]))); preinterpret<vecHalf>(
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
} }
} }
aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_); aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
} }
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM> template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
aux_storeC(Scalar *C_arr, int64_t LDC, Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0)
{
EIGEN_UNUSED_VARIABLE(C_arr); EIGEN_UNUSED_VARIABLE(C_arr);
EIGEN_UNUSED_VARIABLE(LDC); EIGEN_UNUSED_VARIABLE(LDC);
EIGEN_UNUSED_VARIABLE(zmm); EIGEN_UNUSED_VARIABLE(zmm);
@ -203,9 +206,9 @@ public:
} }
template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM> template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
void storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0){ int64_t remM_ = 0) {
aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_); aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
} }
@ -235,8 +238,7 @@ public:
* avx registers are being transposed. * avx registers are being transposed.
*/ */
template <int64_t unrollN, int64_t packetIndexOffset> template <int64_t unrollN, int64_t packetIndexOffset>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
void transpose(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
constexpr int64_t zmmStride = unrollN / PacketSize; constexpr int64_t zmmStride = unrollN / PacketSize;
@ -298,27 +300,25 @@ public:
* for(startN = 0; startN < endN; startN++) * for(startN = 0; startN < endN; startN++)
**/ **/
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM> template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
aux_loadB(Scalar *B_arr, int64_t LDB, Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) { int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN - counter; constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse; constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(remM) { EIGEN_IF_CONSTEXPR(remM) {
ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>( ymm.packet[packetIndexOffset + startN] =
(const Scalar*)&B_arr[startN*LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)); ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
} }
else else ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar*)&B_arr[startN*LDB]);
aux_loadB<endN, counter - 1, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_); aux_loadB<endN, counter - 1, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
} }
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM> template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
aux_loadB(Scalar *B_arr, int64_t LDB, Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) int64_t remM_ = 0) {
{
EIGEN_UNUSED_VARIABLE(B_arr); EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(ymm); EIGEN_UNUSED_VARIABLE(ymm);
@ -332,16 +332,13 @@ public:
* for(startN = 0; startN < endN; startN++) * for(startN = 0; startN < endN; startN++)
**/ **/
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM> template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
aux_storeB(Scalar *B_arr, int64_t LDB, Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
constexpr int64_t counterReverse = endN - counter; constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse; constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(remK || remM) { EIGEN_IF_CONSTEXPR(remK || remM) {
pstoreu<Scalar>( pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
&B_arr[startN*LDB],
ymm.packet[packetIndexOffset + startN],
remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_)); remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
} }
else { else {
@ -352,10 +349,8 @@ public:
} }
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM> template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
aux_storeB(Scalar *B_arr, int64_t LDB, Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_arr); EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(ymm); EIGEN_UNUSED_VARIABLE(ymm);
@ -369,23 +364,19 @@ public:
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
**/ **/
template <int64_t endN, int64_t counter, bool toTemp, bool remM> template <int64_t endN, int64_t counter, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN - counter; constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse; constexpr int64_t startN = counterReverse;
transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false>(&B_temp[startN], LDB_, ymm); transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false>(&B_temp[startN], LDB_, ymm);
aux_loadBBlock<endN, counter-EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>( aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
B_arr, LDB, B_temp, LDB_, ymm, remM_);
} }
template <int64_t endN, int64_t counter, bool toTemp, bool remM> template <int64_t endN, int64_t counter, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
int64_t remM_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_arr); EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(B_temp); EIGEN_UNUSED_VARIABLE(B_temp);
@ -394,7 +385,6 @@ public:
EIGEN_UNUSED_VARIABLE(remM_); EIGEN_UNUSED_VARIABLE(remM_);
} }
/** /**
* aux_storeBBlock * aux_storeBBlock
* *
@ -402,31 +392,26 @@ public:
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
**/ **/
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_> template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
aux_storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN - counter; constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse; constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(toTemp) { EIGEN_IF_CONSTEXPR(toTemp) {
transB::template storeB<EIGEN_AVX_MAX_NUM_ROW,startN, remK_ != 0, false>( transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
&B_temp[startN], LDB_, ymm, remK_);
} }
else { else {
transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW,endN),startN, false, remM>( transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
&B_arr[0 + startN*LDB], LDB, ymm, remM_); ymm, remM_);
} }
aux_storeBBlock<endN, counter-EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>( aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
B_arr, LDB, B_temp, LDB_, ymm, remM_);
} }
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_> template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
aux_storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
int64_t remM_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_arr); EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(B_temp); EIGEN_UNUSED_VARIABLE(B_temp);
@ -435,51 +420,43 @@ public:
EIGEN_UNUSED_VARIABLE(remM_); EIGEN_UNUSED_VARIABLE(remM_);
} }
/******************************************************** /********************************************************
* Wrappers for aux_XXXX to hide counter parameter * Wrappers for aux_XXXX to hide counter parameter
********************************************************/ ********************************************************/
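// Illustrative sketch (not part of the original source): every aux_XXXX/XXXX pair in this file
// emulates a compile-time "for" loop. The aux_ version recurses on a counter that runs from the
// trip count down to zero, two std::enable_if_t overloads select between the loop body and the
// empty base case, and the public wrapper hides the counter by seeding it with the trip count.
// A minimal standalone analogue of the pattern (aux_touch/touch are hypothetical names):
//
//   template <int64_t endN, int64_t counter>
//   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_touch(Scalar *p) {
//     constexpr int64_t startN = endN - counter;   // current "iteration" index
//     p[startN] = Scalar(0);                       // stand-in for the real loop body
//     aux_touch<endN, counter - 1>(p);             // recurse: next iteration
//   }
//   template <int64_t endN, int64_t counter>
//   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_touch(Scalar *) {}  // base case
//   template <int64_t endN>
//   static EIGEN_ALWAYS_INLINE void touch(Scalar *p) { aux_touch<endN, endN>(p); }      // hides counter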
template <int64_t endN, int64_t packetIndexOffset, bool remM> template <int64_t endN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
void loadB(Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) { int64_t remM_ = 0) {
aux_loadB<endN, endN, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_); aux_loadB<endN, endN, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
} }
template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM> template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
void storeB(Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) { int64_t rem_ = 0) {
aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_); aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
} }
template <int64_t unrollN, bool toTemp, bool remM> template <int64_t unrollN, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) { int64_t remM_ = 0) {
EIGEN_IF_CONSTEXPR(toTemp) { EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM>(&B_arr[0], LDB, ymm, remM_); }
transB::template loadB<unrollN,0,remM>(&B_arr[0],LDB, ymm, remM_);
}
else { else {
aux_loadBBlock<unrollN, unrollN, toTemp, remM>( aux_loadBBlock<unrollN, unrollN, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
B_arr, LDB, B_temp, LDB_, ymm, remM_);
} }
} }
template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_> template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) { int64_t remM_ = 0) {
aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>( aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
B_arr, LDB, B_temp, LDB_, ymm, remM_);
} }
template <int64_t packetIndexOffset> template <int64_t packetIndexOffset>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
void transposeLxL(PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm){
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r; PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
@ -503,9 +480,9 @@ public:
} }
template <int64_t unrollN, bool toTemp, bool remM> template <int64_t unrollN, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
PacketBlock<vecHalf,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) { int64_t remM_ = 0) {
constexpr int64_t U3 = PacketSize * 3; constexpr int64_t U3 = PacketSize * 3;
constexpr int64_t U2 = PacketSize * 2; constexpr int64_t U2 = PacketSize * 2;
constexpr int64_t U1 = PacketSize * 1; constexpr int64_t U1 = PacketSize * 1;
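// Illustrative note (not in the original source): with AVX512 packets, PacketSize is 16 for float
// and 8 for double, so (U1, U2, U3) is (16, 32, 48) for float and (8, 16, 24) for double.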
@ -526,11 +503,13 @@ public:
transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_); transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
EIGEN_IF_CONSTEXPR(maxUBlock < U3) { EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
transB::template loadBBlock<maxUBlock,toTemp, remM>(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
ymm, remM_);
transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template storeBBlock<maxUBlock,toTemp, remM,0>(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
ymm, remM_);
} }
} }
else EIGEN_IF_CONSTEXPR(unrollN == U2) { else EIGEN_IF_CONSTEXPR(unrollN == U2) {
@ -543,20 +522,18 @@ public:
transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_); transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
EIGEN_IF_CONSTEXPR(maxUBlock < U2) { EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW,toTemp, remM>( transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); &B_temp[maxUBlock], LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm); transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW,toTemp,remM,0>( transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); &B_temp[maxUBlock], LDB_, ymm, remM_);
} }
} }
else EIGEN_IF_CONSTEXPR(unrollN == U1) { else EIGEN_IF_CONSTEXPR(unrollN == U1) {
// load LxU1 B col major, transpose LxU1 row major // load LxU1 B col major, transpose LxU1 row major
transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm); transB::template transposeLxL<0>(ymm);
EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm);
}
transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_); transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
} }
else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) { else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
@ -601,9 +578,7 @@ public:
template <typename Scalar> template <typename Scalar>
class trsm { class trsm {
public: public:
using vec = typename std::conditional<std::is_same<Scalar, float>::value, using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
vecFullFloat,
vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size; static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/*********************************** /***********************************
@ -622,9 +597,8 @@ public:
* for(startK = 0; startK < endK; startK++) * for(startK = 0; startK < endK; startK++)
**/ **/
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem> template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) { Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
constexpr int64_t counterReverse = endM * endK - counter; constexpr int64_t counterReverse = endM * endK - counter;
constexpr int64_t startM = counterReverse / (endK); constexpr int64_t startM = counterReverse / (endK);
constexpr int64_t startK = counterReverse % endK; constexpr int64_t startK = counterReverse % endK;
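// Illustrative note (not in the original source): counterReverse linearizes the 2-D unroll.
// For example, with endM == 2 and endK == 3 the counter runs 6, 5, ..., 1, counterReverse runs
// 0..5, and (startM, startK) visits (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) in that order.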
@ -642,9 +616,8 @@ public:
} }
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem> template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
{
EIGEN_UNUSED_VARIABLE(B_arr); EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(RHSInPacket); EIGEN_UNUSED_VARIABLE(RHSInPacket);
@ -659,8 +632,8 @@ public:
* for(startK = 0; startK < endK; startK++) * for(startK = 0; startK < endK; startK++)
**/ **/
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem> template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) { Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
constexpr int64_t counterReverse = endM * endK - counter; constexpr int64_t counterReverse = endM * endK - counter;
constexpr int64_t startM = counterReverse / (endK); constexpr int64_t startM = counterReverse / (endK);
constexpr int64_t startK = counterReverse % endK; constexpr int64_t startK = counterReverse % endK;
@ -678,9 +651,8 @@ public:
} }
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem> template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
{
EIGEN_UNUSED_VARIABLE(B_arr); EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(RHSInPacket); EIGEN_UNUSED_VARIABLE(RHSInPacket);
@ -696,8 +668,8 @@ public:
* for(startK = 0; startK < endK; startK++) * for(startK = 0; startK < endK; startK++)
**/ **/
template <int64_t currM, int64_t endK, int64_t counter> template <int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
aux_divRHSByDiag(PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) { PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
constexpr int64_t counterReverse = endK - counter; constexpr int64_t counterReverse = endK - counter;
constexpr int64_t startK = counterReverse; constexpr int64_t startK = counterReverse;
@ -707,8 +679,8 @@ public:
} }
template <int64_t currM, int64_t endK, int64_t counter> template <int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
aux_divRHSByDiag(PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) { PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
EIGEN_UNUSED_VARIABLE(RHSInPacket); EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(AInPacket); EIGEN_UNUSED_VARIABLE(AInPacket);
} }
@ -720,10 +692,11 @@ public:
* for(startM = initM; startM < endM; startM++) * for(startM = initM; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++) * for(startK = 0; startK < endK; startK++)
**/ **/
template<bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK, int64_t counter, int64_t currentM> template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> int64_t counter, int64_t currentM>
aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) { static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
constexpr int64_t counterReverse = (endM - initM) * endK - counter; constexpr int64_t counterReverse = (endM - initM) * endK - counter;
constexpr int64_t startM = initM + counterReverse / (endK); constexpr int64_t startM = initM + counterReverse / (endK);
constexpr int64_t startK = counterReverse % endK; constexpr int64_t startK = counterReverse % endK;
@ -732,8 +705,7 @@ public:
constexpr int64_t packetIndex = startM * endK + startK; constexpr int64_t packetIndex = startM * endK + startK;
EIGEN_IF_CONSTEXPR(currentM > 0) { EIGEN_IF_CONSTEXPR(currentM > 0) {
RHSInPacket.packet[packetIndex] = RHSInPacket.packet[packetIndex] =
pnmadd(AInPacket.packet[startM], pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
RHSInPacket.packet[(currentM-1)*endK+startK],
RHSInPacket.packet[packetIndex]); RHSInPacket.packet[packetIndex]);
} }
@ -744,24 +716,25 @@ public:
// This will be used in divRHSByDiag // This will be used in divRHSByDiag
EIGEN_IF_CONSTEXPR(isFWDSolve) EIGEN_IF_CONSTEXPR(isFWDSolve)
AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]); AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
else else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
AInPacket.packet[currentM] = pset1<vec>(Scalar(1)/A_arr[idA<isARowMajor>(-currentM,-currentM,LDA)]);
} }
else { else {
// Broadcast next off diagonal element of A // Broadcast next off diagonal element of A
EIGEN_IF_CONSTEXPR(isFWDSolve) EIGEN_IF_CONSTEXPR(isFWDSolve)
AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]); AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
else else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM,-currentM,LDA)]);
} }
} }
aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(A_arr, LDA, RHSInPacket, AInPacket); aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
A_arr, LDA, RHSInPacket, AInPacket);
} }
template<bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK, int64_t counter, int64_t currentM> template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> int64_t counter, int64_t currentM>
aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) { static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
EIGEN_UNUSED_VARIABLE(A_arr); EIGEN_UNUSED_VARIABLE(A_arr);
EIGEN_UNUSED_VARIABLE(LDA); EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(RHSInPacket); EIGEN_UNUSED_VARIABLE(RHSInPacket);
@ -775,9 +748,9 @@ public:
* for(startM = 0; startM < endM; startM++) * for(startM = 0; startM < endM; startM++)
**/ **/
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK> template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) { Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
constexpr int64_t counterReverse = endM - counter; constexpr int64_t counterReverse = endM - counter;
constexpr int64_t startM = counterReverse; constexpr int64_t startM = counterReverse;
@ -794,21 +767,21 @@ public:
// After division, the rhs corresponding to subsequent rows of A can be partially updated // After division, the rhs corresponding to subsequent rows of A can be partially updated
// We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed) // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
// to be used in the next iteration. // to be used in the next iteration.
trsm::template trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>( AInPacket);
A_arr, LDA, RHSInPacket, AInPacket);
// Handle division for the RHS corresponding to the final row of A. // Handle division for the RHS corresponding to the final row of A.
EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1) EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket); trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);
aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket, AInPacket); aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
AInPacket);
} }
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK> template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
{ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
EIGEN_UNUSED_VARIABLE(A_arr); EIGEN_UNUSED_VARIABLE(A_arr);
EIGEN_UNUSED_VARIABLE(LDA); EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(RHSInPacket); EIGEN_UNUSED_VARIABLE(RHSInPacket);
@ -824,8 +797,8 @@ public:
* Masked loads are used for cases where endK is not a multiple of PacketSize * Masked loads are used for cases where endK is not a multiple of PacketSize
*/ */
template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false> template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
void loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) { PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem); aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
} }
@ -834,8 +807,8 @@ public:
* Masked stores are used for cases where endK is not a multiple of PacketSize * Masked stores are used for cases where endK is not a multiple of PacketSize
*/ */
template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false> template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
void storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) { PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem); aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
} }
@ -843,8 +816,8 @@ public:
* Only used if Triangular matrix has non-unit diagonal values * Only used if Triangular matrix has non-unit diagonal values
*/ */
template <int64_t currM, int64_t endK> template <int64_t currM, int64_t endK>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
void divRHSByDiag(PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) { PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket); aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
} }
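// Illustrative note (not in the original source): the "division" is presumably a multiply by the
// reciprocal 1/A(i,i) that updateRHS broadcast into AInPacket, so only one scalar divide per
// diagonal entry is needed rather than one divide per RHS packet.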
@ -852,9 +825,11 @@ public:
* Update right-hand sides (stored in avx registers) * Update right-hand sides (stored in avx registers)
* Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket. * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket.
**/ **/
template<bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK, int64_t currentM> template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
static EIGEN_ALWAYS_INLINE int64_t currentM>
void updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) { static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>( aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
A_arr, LDA, RHSInPacket, AInPacket); A_arr, LDA, RHSInPacket, AInPacket);
} }
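// Illustrative note (not in the original source): for a forward solve this is, in essence, the
// substitution step b_i := b_i - A(i, currentM) * b_currentM, implemented with pnmadd on packets
// of right-hand sides, with i running over the rows of A that have not yet been solved.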
@ -866,11 +841,11 @@ public:
* isUnitDiag: true => triangular matrix has unit diagonal. * isUnitDiag: true => triangular matrix has unit diagonal.
*/ */
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK> template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec,EIGEN_AVX_MAX_NUM_ROW> &AInPacket) { PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
static_assert(numK >= 1 && numK <= 3, "numK out of range"); static_assert(numK >= 1 && numK <= 3, "numK out of range");
aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>( aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
A_arr, LDA, RHSInPacket, AInPacket);
} }
}; };
@ -902,8 +877,8 @@ public:
* for(startN = 0; startN < endN; startN++) * for(startN = 0; startN < endN; startN++)
**/ **/
template <int64_t endM, int64_t endN, int64_t counter> template <int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
aux_setzero(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) { PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
constexpr int64_t counterReverse = endM * endN - counter; constexpr int64_t counterReverse = endM * endN - counter;
constexpr int64_t startM = counterReverse / (endN); constexpr int64_t startM = counterReverse / (endN);
constexpr int64_t startN = counterReverse % endN; constexpr int64_t startN = counterReverse % endN;
@ -913,9 +888,8 @@ public:
} }
template <int64_t endM, int64_t endN, int64_t counter> template <int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
aux_setzero(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
{
EIGEN_UNUSED_VARIABLE(zmm); EIGEN_UNUSED_VARIABLE(zmm);
} }
@ -927,8 +901,8 @@ public:
* for(startN = 0; startN < endN; startN++) * for(startN = 0; startN < endN; startN++)
**/ **/
template <int64_t endM, int64_t endN, int64_t counter, bool rem> template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) { Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endM * endN - counter; constexpr int64_t counterReverse = endM * endN - counter;
constexpr int64_t startM = counterReverse / (endN); constexpr int64_t startM = counterReverse / (endN);
@ -937,18 +911,15 @@ public:
EIGEN_IF_CONSTEXPR(rem) EIGEN_IF_CONSTEXPR(rem)
zmm.packet[startN * endM + startM] = zmm.packet[startN * endM + startM] =
padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)), padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
zmm.packet[startN*endM + startM], zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
remMask<PacketSize>(rem_)); else zmm.packet[startN * endM + startM] =
else
zmm.packet[startN*endM + startM] =
padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]); padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_); aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
} }
template <int64_t endM, int64_t endN, int64_t counter, bool rem> template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
{
EIGEN_UNUSED_VARIABLE(C_arr); EIGEN_UNUSED_VARIABLE(C_arr);
EIGEN_UNUSED_VARIABLE(LDC); EIGEN_UNUSED_VARIABLE(LDC);
EIGEN_UNUSED_VARIABLE(zmm); EIGEN_UNUSED_VARIABLE(zmm);
@ -963,24 +934,23 @@ public:
* for(startN = 0; startN < endN; startN++) * for(startN = 0; startN < endN; startN++)
**/ **/
template <int64_t endM, int64_t endN, int64_t counter, bool rem> template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) { Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endM * endN - counter; constexpr int64_t counterReverse = endM * endN - counter;
constexpr int64_t startM = counterReverse / (endN); constexpr int64_t startM = counterReverse / (endN);
constexpr int64_t startN = counterReverse % endN; constexpr int64_t startN = counterReverse % endN;
EIGEN_IF_CONSTEXPR(rem) EIGEN_IF_CONSTEXPR(rem)
pstoreu<Scalar>(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM], remMask<PacketSize>(rem_)); pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
else remMask<PacketSize>(rem_));
pstoreu<Scalar>(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM]); else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_); aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
} }
template <int64_t endM, int64_t endN, int64_t counter, bool rem> template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
{
EIGEN_UNUSED_VARIABLE(C_arr); EIGEN_UNUSED_VARIABLE(C_arr);
EIGEN_UNUSED_VARIABLE(LDC); EIGEN_UNUSED_VARIABLE(LDC);
EIGEN_UNUSED_VARIABLE(zmm); EIGEN_UNUSED_VARIABLE(zmm);
@ -994,8 +964,8 @@ public:
* for(startL = 0; startL < endL; startL++) * for(startL = 0; startL < endL; startL++)
**/ **/
template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem> template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
aux_startLoadB(Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) { Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endL - counter; constexpr int64_t counterReverse = endL - counter;
constexpr int64_t startL = counterReverse; constexpr int64_t startL = counterReverse;
@ -1003,18 +973,15 @@ public:
EIGEN_IF_CONSTEXPR(rem) EIGEN_IF_CONSTEXPR(rem)
zmm.packet[unrollM * unrollN + startL] = zmm.packet[unrollM * unrollN + startL] =
ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_)); ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
else else zmm.packet[unrollM * unrollN + startL] =
zmm.packet[unrollM*unrollN+startL] = ploadu<vec>(&B_t[(startL/unrollM)*LDB + (startL%unrollM)*PacketSize]); ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);
aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_); aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
} }
template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem> template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
aux_startLoadB( Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
Scalar *B_t, int64_t LDB,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_t); EIGEN_UNUSED_VARIABLE(B_t);
EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(zmm); EIGEN_UNUSED_VARIABLE(zmm);
@ -1028,8 +995,8 @@ public:
* for(startB = 0; startB < endB; startB++) * for(startB = 0; startB < endB; startB++)
**/ **/
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad> template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) { Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
constexpr int64_t counterReverse = endB - counter; constexpr int64_t counterReverse = endB - counter;
constexpr int64_t startB = counterReverse; constexpr int64_t startB = counterReverse;
@ -1039,9 +1006,8 @@ public:
} }
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad> template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
{
EIGEN_UNUSED_VARIABLE(A_t); EIGEN_UNUSED_VARIABLE(A_t);
EIGEN_UNUSED_VARIABLE(LDA); EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(zmm); EIGEN_UNUSED_VARIABLE(zmm);
@ -1054,9 +1020,10 @@ public:
* 1-D unroll * 1-D unroll
* for(startM = 0; startM < endM; startM++) * for(startM = 0; startM < endM; startM++)
**/ **/
template<int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem> template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> int64_t numBCast, bool rem>
aux_loadB(Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) { static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
if ((numLoad / endM + currK < unrollK)) { if ((numLoad / endM + currK < unrollK)) {
constexpr int64_t counterReverse = endM - counter; constexpr int64_t counterReverse = endM - counter;
@ -1075,12 +1042,10 @@ public:
} }
} }
template<int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem> template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> int64_t numBCast, bool rem>
aux_loadB( static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
Scalar *B_t, int64_t LDB, Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0)
{
EIGEN_UNUSED_VARIABLE(B_t); EIGEN_UNUSED_VARIABLE(B_t);
EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(zmm); EIGEN_UNUSED_VARIABLE(zmm);
@ -1095,11 +1060,11 @@ public:
* for(startN = 0; startN < endN; startN++) * for(startN = 0; startN < endN; startN++)
* for(startK = 0; startK < endK; startK++) * for(startK = 0; startK < endK; startK++)
**/ **/
template<bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad, int64_t numBCast, bool rem> template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> int64_t numBCast, bool rem>
aux_microKernel( static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA, Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) { int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endM * endN * endK - counter; constexpr int64_t counterReverse = endM * endN * endK - counter;
constexpr int startK = counterReverse / (endM * endN); constexpr int startK = counterReverse / (endM * endN);
@ -1107,10 +1072,8 @@ public:
constexpr int startM = counterReverse % endM; constexpr int startM = counterReverse % endM;
EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) { EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
gemm:: template gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_); gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
gemm:: template
startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
} }
{ {
@ -1127,9 +1090,8 @@ public:
} }
// Bcast // Bcast
EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) { EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast] = zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
pload1<vec>(&A_t[idA<isARowMajor>((numBCast + startN + startK*endN)%endN, (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
(numBCast + startN + startK*endN)/endN, LDA)]);
} }
} }
@ -1138,15 +1100,13 @@ public:
gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_); gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
} }
aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_); aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
} }
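// Illustrative note (not in the original source): each of the endM*endN*endK unrolled steps is
// expected to issue one fused multiply-add, accumulating a broadcast element of A times a packet
// of B into zmm.packet[startN*endM + startM]; the startLoadB/startBCastA/loadB helpers keep
// numLoad B packets and numBCast A broadcasts in flight so the FMAs are not starved.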
template<bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad, int64_t numBCast, bool rem> template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> int64_t numBCast, bool rem>
aux_microKernel( static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA, Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) int64_t rem_ = 0) {
{
EIGEN_UNUSED_VARIABLE(B_t); EIGEN_UNUSED_VARIABLE(B_t);
EIGEN_UNUSED_VARIABLE(A_t); EIGEN_UNUSED_VARIABLE(A_t);
EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(LDB);
@@ -1160,8 +1120,7 @@ public:
********************************************************/ ********************************************************/
template <int64_t endM, int64_t endN> template <int64_t endM, int64_t endN>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
void setzero(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm){
aux_setzero<endM, endN, endM * endN>(zmm); aux_setzero<endM, endN, endM * endN>(zmm);
} }
@@ -1169,15 +1128,17 @@ public:
* Ideally the compiler folds these into vaddp{s,d} with an embedded memory load. * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
*/ */
template <int64_t endM, int64_t endN, bool rem = false> template <int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
void updateC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_); aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
} }
template <int64_t endM, int64_t endN, bool rem = false> template <int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
void storeC(Scalar *C_arr, int64_t LDC, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_); aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
} }
@@ -1186,8 +1147,9 @@ public:
* Use numLoad registers for loading B at start of microKernel * Use numLoad registers for loading B at start of microKernel
*/ */
template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem> template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
void startLoadB(Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_); aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
} }
@@ -1196,8 +1158,8 @@ public:
* Use numBCast registers for broadcasting A at start of microKernel * Use numBCast registers for broadcasting A at start of microKernel
*/ */
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad> template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
void startBCastA(Scalar *A_t, int64_t LDA, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm){ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm); aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
} }
@@ -1205,9 +1167,9 @@ public:
* Loads next set of B into vector registers between each K unroll. * Loads next set of B into vector registers between each K unroll.
*/ */
template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem> template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
void loadB( PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
Scalar *B_t, int64_t LDB, PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){ int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_); aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
} }
@@ -1235,18 +1197,16 @@ public:
* From testing, there are no register spills with clang. There are register spills with GNU, which * From testing, there are no register spills with clang. There are register spills with GNU, which
* causes a performance hit. * causes a performance hit.
*/ */
template<bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast, bool rem = false> template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
static EIGEN_ALWAYS_INLINE bool rem = false>
void microKernel( static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0){ int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_); EIGEN_UNUSED_VARIABLE(rem_);
aux_microKernel<isARowMajor,endM, endN, endK, endM*endN*endK, numLoad, numBCast, rem>( aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
B_t, A_t, LDB, LDA, zmm, rem_); rem_);
} }
}; };
} // namespace unrolls } // namespace unrolls
#endif // EIGEN_UNROLLS_IMPL_H #endif // EIGEN_UNROLLS_IMPL_H
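
For reference, the counter-driven recursion above (the aux_microKernel pair gated on std::enable_if_t<(counter > 0)> and std::enable_if_t<(counter <= 0)>) is the standard trick for unrolling a fixed-trip loop nest at compile time. Below is a minimal standalone sketch of that pattern only — it is not Eigen code, and the names unroll3 and its printf body are hypothetical, showing just how startK/startN/startM fall out of a single decreasing counter:

// Minimal sketch (assumed illustration, not Eigen's kernel): one compile-time
// counter replaces three nested loops; the indices are recovered from the
// counter exactly as startK/startN/startM are derived from counterReverse above.
#include <cstdint>
#include <cstdio>
#include <type_traits>

template <int64_t endM, int64_t endN, int64_t endK, int64_t counter>
static inline std::enable_if_t<(counter > 0)> unroll3() {
  constexpr int64_t counterReverse = endM * endN * endK - counter;
  constexpr int64_t k = counterReverse / (endM * endN);  // outermost index
  constexpr int64_t n = (counterReverse / endM) % endN;  // middle index
  constexpr int64_t m = counterReverse % endM;           // innermost, varies fastest
  std::printf("k=%lld n=%lld m=%lld\n", (long long)k, (long long)n, (long long)m);
  unroll3<endM, endN, endK, counter - 1>();  // tail of the unrolled loop nest
}

template <int64_t endM, int64_t endN, int64_t endK, int64_t counter>
static inline std::enable_if_t<(counter <= 0)> unroll3() {}  // recursion terminator

int main() {
  unroll3<2, 2, 2, 2 * 2 * 2>();  // emits all 8 (k, n, m) combinations, fully inlined at -O2
  return 0;
}
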
View File
@@ -171,7 +171,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
std::ptrdiff_t l1, l2, l3; std::ptrdiff_t l1, l2, l3;
manage_caching_sizes(GetAction, &l1, &l2, &l3); manage_caching_sizes(GetAction, &l1, &l2, &l3);
#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSML_CUTOFFS) #if (EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 && EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
(std::is_same<Scalar,float>::value || (std::is_same<Scalar,float>::value ||
std::is_same<Scalar,double>::value)) ) { std::is_same<Scalar,double>::value)) ) {
@@ -246,7 +246,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
// tr solve // tr solve
{ {
Index i = IsLower ? k2+k1 : k2-k1; Index i = IsLower ? k2+k1 : k2-k1;
#if defined(EIGEN_USE_AVX512_TRSM_L_KERNELS) #if EIGEN_USE_AVX512_TRSM_L_KERNELS
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 && EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
(std::is_same<Scalar,float>::value || (std::is_same<Scalar,float>::value ||
std::is_same<Scalar,double>::value)) ) { std::is_same<Scalar,double>::value)) ) {
@@ -318,7 +318,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
{ {
Index rows = otherSize; Index rows = otherSize;
#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS) #if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 && EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
(std::is_same<Scalar,float>::value || (std::is_same<Scalar,float>::value ||
std::is_same<Scalar,double>::value)) ) { std::is_same<Scalar,double>::value)) ) {
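
Note that these guards now test macro values (e.g. #if EIGEN_USE_AVX512_TRSM_L_KERNELS) rather than #if defined(...), so the switches are meant to be driven by 0/1 definitions. A hedged usage sketch follows, assuming a client build may pre-define the macro to 0 before including Eigen to opt out of the optimized paths (the compiler flags shown are only an example, not prescribed by this commit):

// Assumed usage, e.g.:  g++ -O2 -march=skylake-avx512 -DEIGEN_USE_AVX512_TRSM_KERNELS=0 app.cpp
#define EIGEN_USE_AVX512_TRSM_KERNELS 0  // value-style override; 1 is presumably the default
#include <Eigen/Dense>

int main() {
  Eigen::MatrixXd A = Eigen::MatrixXd::Random(64, 64);
  Eigen::MatrixXd b = Eigen::MatrixXd::Random(64, 8);
  // A triangular solve that would otherwise be eligible for the optimized TRSM kernels.
  Eigen::MatrixXd x = A.triangularView<Eigen::Lower>().solve(b);
  return x.allFinite() ? 0 : 1;
}
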