mirror of https://gitlab.com/libeigen/eigen.git (synced 2025-05-03 01:04:23 +08:00)
AVX512 TRSM kernels use alloca if EIGEN_NO_MALLOC requested
This commit is contained in:
parent 4d1c16eab8
commit 37673ca1bc
@@ -12,13 +12,20 @@

 #include "../../InternalHeaderCheck.h"

-#define EIGEN_USE_AVX512_TRSM_KERNELS  // Comment out to prevent using optimized trsm kernels.
-
-#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
-#define EIGEN_USE_AVX512_TRSM_R_KERNELS
-#if !defined(EIGEN_NO_MALLOC)  // Separate MACRO since these kernels require malloc
-#define EIGEN_USE_AVX512_TRSM_L_KERNELS
-#endif
-#endif
+#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
+#define EIGEN_USE_AVX512_TRSM_KERNELS 1
+#endif
+
+#if EIGEN_USE_AVX512_TRSM_KERNELS
+#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
+#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
+#endif
+#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
+#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
+#endif
+#else  // EIGEN_USE_AVX512_TRSM_KERNELS == 0
+#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
+#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
+#endif

 #if defined(EIGEN_HAS_CXX17_IFCONSTEXPR)
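Because the kernel macros now default to 0/1 values instead of being defined or undefined, a build can override any of them before Eigen is included. The snippet below is a hedged sketch of that assumed usage, not part of the commit: it presumes a translation unit compiled with AVX-512 enabled, and the matrix sizes are arbitrary.

// Sketch: opting out of the AVX512 left-solve kernels in an EIGEN_NO_MALLOC build.
#define EIGEN_NO_MALLOC
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0  // honored because the header only sets a default when undefined
#include <Eigen/Dense>

int main() {
  Eigen::Matrix<float, 16, 16> A = Eigen::Matrix<float, 16, 16>::Random();
  Eigen::Matrix<float, 16, 4> B = Eigen::Matrix<float, 16, 4>::Random();
  A.triangularView<Eigen::Lower>().solveInPlace(B);  // left solves fall back to the packed path with the macro at 0
  return 0;
}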
@@ -49,8 +56,7 @@ typedef Packet4d vecHalfDouble;

 // Note: this depends on macros and typedefs above.
 #include "TrsmUnrolls.inc"

-#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
+#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
 /**
  * For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly
  * is faster than the packed versions in TriangularSolverMatrix.h.
@@ -67,34 +73,46 @@ typedef Packet4d vecHalfDouble;
  * M = Dimension of triangular matrix
  *
  */
-#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS  // Comment out to disable no-copy dispatch
-
-#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
-#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
-#if !defined(EIGEN_NO_MALLOC)
-#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
-#endif
-#endif
+#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
+#endif
+
+#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
+
+#if EIGEN_USE_AVX512_TRSM_R_KERNELS
+#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
+#endif  // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
+#endif
+
+#if EIGEN_USE_AVX512_TRSM_L_KERNELS
+#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
+#endif
+#endif  // EIGEN_USE_AVX512_TRSM_L_KERNELS
+
+#else  // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
+#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
+#endif  // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS

 template <typename Scalar>
 int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
   const int64_t U3 = 3 * packet_traits<Scalar>::size;
   const int64_t MaxNb = 5 * U3;
   int64_t Nb = std::min(MaxNb, N);
   double cutoff_d =
       (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
   int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
   return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
 }
 #endif

 /**
  * Used by gemmKernel for the case A/B row-major and C col-major.
  */
 template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
 static EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                             Scalar *C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
   EIGEN_UNUSED_VARIABLE(remN_);
   EIGEN_UNUSED_VARIABLE(remM_);
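The cutoff above estimates how many rows of the triangular matrix, together with an Nb-column slice of the right-hand side, fit in a fraction of L2 cache, then rounds down to a multiple of EIGEN_AVX_MAX_NUM_ROW. A worked instance with assumed numbers (all values below are illustrative, not taken from the source):

// Assumed: Scalar = float (AVX-512 packet size 16), EIGEN_AVX_MAX_NUM_ROW = 8,
// L2Size = 1 MiB, L2Cap = 0.5, N = 512 right-hand sides.
#include <algorithm>
#include <cstdint>

int64_t example_cutoff() {
  const int64_t U3 = 3 * 16;                         // 48
  const int64_t MaxNb = 5 * U3;                      // 240
  const int64_t Nb = std::min<int64_t>(MaxNb, 512);  // 240
  double cutoff_d = ((1048576.0 * 0.5) / sizeof(float) - 8.0 * Nb) / (8.0 + Nb);  // ~520.8
  int64_t cutoff_l = static_cast<int64_t>(cutoff_d);                              // 520
  return (cutoff_l / 8) * 8;  // 520; per the comment above, smaller problems take the no-copy kernel path
}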
@@ -194,16 +212,13 @@ void transStoreC(PacketBlock<vec,EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
  * handleKRem: Handle arbitrary K? This is not needed for trsm.
  */
 template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
 void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
                 int64_t LDC) {
   using urolls = unrolls::gemm<Scalar, isAdd>;
   constexpr int64_t U3 = urolls::PacketSize * 3;
   constexpr int64_t U2 = urolls::PacketSize * 2;
   constexpr int64_t U1 = urolls::PacketSize * 1;
   using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
   int64_t N_ = (N / U3) * U3;
   int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
   int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
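For concreteness, the unroll widths U1/U2/U3 measured in columns of C per iteration, under assumed AVX-512 packet sizes (16 floats or 8 doubles per register; these numbers are assumptions, not stated in the diff):

constexpr int64_t floatU1 = 16 * 1, floatU2 = 16 * 2, floatU3 = 16 * 3;   // 16, 32, 48 columns
constexpr int64_t doubleU1 = 8 * 1, doubleU2 = 8 * 2, doubleU3 = 8 * 3;   //  8, 16, 24 columns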
@@ -216,19 +231,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -245,19 +260,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -275,19 +290,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -306,18 +321,18 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      urolls::template setzero<3, 1>(zmm);
      {
        for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
          urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
          else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
        }
        EIGEN_IF_CONSTEXPR(handleKRem) {
          for (int64_t k = K_; k < K; k++) {
            urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
            B_t += LDB;
            EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
            else A_t += LDA;
          }
        }
        EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -339,19 +354,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -368,19 +383,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -398,19 +413,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -428,18 +443,18 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
                                                                                                        LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -460,19 +475,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -489,19 +504,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -519,19 +534,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -550,17 +565,18 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      urolls::template setzero<1, 1>(zmm);
      {
        for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
          urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
          else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
        }
        EIGEN_IF_CONSTEXPR(handleKRem) {
          for (int64_t k = K_; k < K; k++) {
            urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
            B_t += LDB;
            EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
            else A_t += LDA;
          }
        }
        EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -583,19 +599,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
                                       EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -612,19 +628,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
              B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -642,19 +658,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
              B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -672,19 +688,19 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
            B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
                                                                                            N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
@@ -707,9 +723,7 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
  * The B matrix (RHS) is assumed to be row-major
  */
 template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
 static EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
   static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
   using urolls = unrolls::trsm<Scalar>;
   constexpr int64_t U3 = urolls::PacketSize * 3;
@@ -722,30 +736,30 @@ void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_
  int64_t k = 0;
  while (K - k >= U3) {
    urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
    k += U3;
  }
  if (K - k >= U2) {
    urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
    k += U2;
  }
  if (K - k >= U1) {
    urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
    k += U1;
  }
  if (K - k > 0) {
    // Handle remaining number of RHS
    urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
  }
}
@@ -790,9 +804,8 @@ void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64
  *
  */
 template <typename Scalar, bool toTemp = true, bool remM = false>
 static EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
                                                 int64_t remM_ = 0) {
   EIGEN_UNUSED_VARIABLE(remM_);
   using urolls = unrolls::transB<Scalar>;
   using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
@@ -809,11 +822,13 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
  }
  if (K - k >= U2) {
    urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U2;
    k += U2;
  }
  if (K - k >= U1) {
    urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U1;
    k += U1;
  }
  EIGEN_IF_CONSTEXPR(U1 > 8) {
    // Note: without "if constexpr" this section of code will also be
@@ -821,7 +836,8 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
    // to make sure the counter is not non-negative.
    if (K - k >= 8) {
      urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
      B_temp += 8;
      k += 8;
    }
  }
  EIGEN_IF_CONSTEXPR(U1 > 4) {
@@ -830,19 +846,48 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
    // to make sure the counter is not non-negative.
    if (K - k >= 4) {
      urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
      B_temp += 4;
      k += 4;
    }
  }
  if (K - k >= 2) {
    urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += 2;
    k += 2;
  }
  if (K - k >= 1) {
    urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += 1;
    k += 1;
  }
}

+#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
+/**
+ * Reduce blocking sizes so that the size of the temporary workspace needed is less than "limit" bytes,
+ *  - kB must be at least psize
+ *  - numM must be at least EIGEN_AVX_MAX_NUM_ROW
+ */
+template <typename Scalar, bool isBRowMajor>
+constexpr std::pair<int64_t, int64_t> trsmBlocking(const int64_t limit) {
+  constexpr int64_t psize = packet_traits<Scalar>::size;
+  int64_t kB = 15 * psize;
+  int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
+  // If B is rowmajor, no temp workspace needed, so use default blocking sizes.
+  if (isBRowMajor) return {kB, numM};
+
+  // Very simple heuristic, prefer keeping kB as large as possible to fully use vector registers.
+  for (int64_t k = kB; k > psize; k -= psize) {
+    for (int64_t m = numM; m > EIGEN_AVX_MAX_NUM_ROW; m -= EIGEN_AVX_MAX_NUM_ROW) {
+      if ((((k + psize - 1) / psize + 4) * psize) * m * sizeof(Scalar) < limit) {
+        return {k, m};
+      }
+    }
+  }
+  return {psize, EIGEN_AVX_MAX_NUM_ROW};  // Minimum blocking size required
+}
+#endif  // (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)

 /**
  * Main triangular solve driver
  *
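The bound that trsmBlocking enforces is the size of the row-major copy of B that triSolve allocates: numM rows of kB entries rounded up to a whole packet, plus four packets of padding per row. A worked instance under assumed values (a 131072-byte stack limit is Eigen's usual default for EIGEN_STACK_ALLOCATION_LIMIT, but that default is an assumption here, not part of the diff):

#include <cstdint>

// Same arithmetic as the search loop above, factored out for illustration only.
constexpr int64_t workspaceBytes(int64_t kB, int64_t numM, int64_t psize, int64_t scalarSize) {
  return (((kB + psize - 1) / psize + 4) * psize) * numM * scalarSize;
}

// Default blocking for double: psize = 8, kB = 15 * 8 = 120, numM = 8 * 8 = 64.
// (((120 + 7) / 8 + 4) * 8) * 64 * 8 = 152 * 64 * 8 = 77824 bytes < 131072,
// so under these assumptions the default sizes already fit and are kept.
static_assert(workspaceBytes(120, 64, 8, 8) == 77824, "worked example");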
@@ -870,8 +915,10 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K,
  * Note: For RXX cases M,numRHS should be swapped.
  *
  */
-template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true, bool isUnitDiag = false>
+template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
+          bool isUnitDiag = false>
 void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
+  constexpr int64_t psize = packet_traits<Scalar>::size;
   /**
    * The values for kB, numM were determined experimentally.
    * kB: Number of RHS we process at a time.
@@ -885,8 +932,30 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
    * large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to
    * determine optimal values for numM.
    */
-  const int64_t kB = (3 * packet_traits<Scalar>::size) * 5;  // 5*U3
-  const int64_t numM = 64;
+#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
+  /**
+   * If EIGEN_NO_MALLOC is requested, we try to reduce kB and numM so the maximum temp workspace required is less
+   * than EIGEN_STACK_ALLOCATION_LIMIT. Actual workspace size may be less, depending on the number of vectors to
+   * solve.
+   *  - kB must be at least psize
+   *  - numM must be at least EIGEN_AVX_MAX_NUM_ROW
+   *
+   * If B is row-major, the blocking sizes are not reduced (no temp workspace needed).
+   */
+  constexpr std::pair<int64_t, int64_t> blocking_ = trsmBlocking<Scalar, isBRowMajor>(EIGEN_STACK_ALLOCATION_LIMIT);
+  constexpr int64_t kB = blocking_.first;
+  constexpr int64_t numM = blocking_.second;
+  /**
+   * If the temp workspace size exceeds EIGEN_STACK_ALLOCATION_LIMIT even with the minimum blocking sizes,
+   * we throw an assertion. Use -DEIGEN_USE_AVX512_TRSM_L_KERNELS=0 if necessary
+   */
+  static_assert(!(((((kB + psize - 1) / psize + 4) * psize) * numM * sizeof(Scalar) >= EIGEN_STACK_ALLOCATION_LIMIT) &&
+                  !isBRowMajor),
+                "Temp workspace required is too large.");
+#else
+  constexpr int64_t kB = (3 * psize) * 5;  // 5*U3
+  constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
+#endif

   int64_t sizeBTemp = 0;
   Scalar *B_temp = NULL;
@@ -896,9 +965,17 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
    * transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array.
    * The updated row-major copy of B is reused in the GEMM updates.
    */
-    sizeBTemp = (((std::min(kB, numRHS) + 15) / 16 + 4) * 16) * numM;
-    B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 4096);
+    sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
   }

+#if !defined(EIGEN_NO_MALLOC)
+  EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
+#elif (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC)
+  // Use alloca if malloc not allowed, requested temp workspace size should be less than EIGEN_STACK_ALLOCATION_LIMIT
+  ei_declare_aligned_stack_constructed_variable(Scalar, B_temp_alloca, sizeBTemp, 0);
+  B_temp = B_temp_alloca;
+#endif
+
   for (int64_t k = 0; k < numRHS; k += kB) {
     int64_t bK = numRHS - k > kB ? kB : numRHS - k;
     int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
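A rough model of the stack path above, under the assumption (not spelled out in the diff) that the requested size stays below EIGEN_STACK_ALLOCATION_LIMIT, so Eigen's declaration macro can back the buffer with alloca; the real macro also handles a heap fallback and non-trivial element types, which this sketch omits:

// Simplified illustration only; EIGEN_ALIGNED_ALLOCA and the size check are what the
// static_assert earlier in triSolve is guarding.
Scalar *B_temp_alloca = static_cast<Scalar *>(EIGEN_ALIGNED_ALLOCA(sizeof(Scalar) * sizeBTemp));
B_temp = B_temp_alloca;  // storage lives until triSolve returns, so no explicit free is needed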
@@ -950,11 +1027,8 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
        int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
        int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
        gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
            &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
            EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
      }
      else {
        if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
@@ -971,14 +1045,11 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
          int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
          int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
          gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
              &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
              M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
          offsetBTemp = 0;
          gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
        } else {
          /**
           * If there is enough space in B_temp, we only update the next 8xbK values of B.
           */
@@ -987,11 +1058,8 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
          int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
          int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
          gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
              &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
              EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
        }
      }
    }
@@ -1006,23 +1074,17 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
        int64_t indB_i = isFWDSolve ? 0 : bM;
        int64_t indB_i2 = isFWDSolve ? M_ : 0;
        gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
            &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
            bK, M_, LDA, LDB, LDB);
      }
      else {
        int64_t indA_i = isFWDSolve ? M_ : 0;
        int64_t indA_j = isFWDSolve ? gemmOff : bM;
        int64_t indB_i = isFWDSolve ? M_ : 0;
        int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
        gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
                                                                   B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
                                                                   M_ - gemmOff, LDA, LDT, LDB);
      }
    }
    EIGEN_IF_CONSTEXPR(!isBRowMajor) {
@@ -1030,44 +1092,45 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
      int64_t indB_i = isFWDSolve ? M_ : 0;
      int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
      copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
      triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
                                                                     B_temp + offB_1, bM, bkL, LDA, bkL);
      copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
    }
    else {
      int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
      triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
                                                                     B_arr + k + ind * LDB, bM, bK, LDA, LDB);
    }
  }
 }

+#if !defined(EIGEN_NO_MALLOC)
  EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
+#endif
}

// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
-#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
+#if (EIGEN_USE_AVX512_TRSM_KERNELS)
+#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct trsmKernelR;

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
  triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
@@ -1075,38 +1138,35 @@ EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
  triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif // (EIGEN_USE_AVX512_TRSM_R_KERNELS)

// These trsm kernels require temporary memory allocation
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct trsmKernelL;

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
  triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
@@ -1114,16 +1174,14 @@ EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
  triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
#endif // EIGEN_USE_AVX512_TRSM_KERNELS
} // namespace internal
} // namespace Eigen
#endif // EIGEN_TRSM_KERNEL_IMPL_H

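For illustration only (not part of the patch): with an AVX512 build and EIGEN_USE_AVX512_TRSM_KERNELS left at its default of 1, an ordinary triangular solve is what ends up dispatching into the trsmKernelL/R specializations above. A minimal sketch:

#include <Eigen/Dense>

int main() {
  // Solve A * X = B in place, with A lower triangular; on an AVX512 target
  // this routes through the optimized trsm kernels declared in this file.
  Eigen::MatrixXf A = Eigen::MatrixXf::Random(64, 64);
  Eigen::MatrixXf B = Eigen::MatrixXf::Random(64, 32);
  A.diagonal().array() += 64.0f;  // keep the triangle well conditioned
  A.triangularView<Eigen::Lower>().solveInPlace(B);  // B now holds X
  return 0;
}

The hunks that follow are in the unroll helper header (the EIGEN_UNROLLS_IMPL_H include) and are almost entirely formatting changes.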
@@ -11,8 +11,7 @@
#define EIGEN_UNROLLS_IMPL_H

template <bool isARowMajor = true>
static EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
  EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
  else return i + j * LDA;
}
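For orientation (illustrative, not part of the patch): idA flattens a (row, column) pair into a linear offset, with LDA as the leading dimension of A. A quick self-check, assuming the helper above is in scope:

#include <cassert>
#include <cstdint>

void idA_example() {
  const int64_t LDA = 4;
  assert(idA<true>(2, 3, LDA) == 2 * LDA + 3);   // row-major: offset 11
  assert(idA<false>(2, 3, LDA) == 2 + 3 * LDA);  // column-major: offset 14
}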
@@ -60,8 +59,12 @@ namespace unrolls {
template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
  EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
  else EIGEN_IF_CONSTEXPR(N == 8) {
    return 0xFF >> (8 - m);
  }
  else EIGEN_IF_CONSTEXPR(N == 4) {
    return 0x0F >> (4 - m);
  }
  return 0;
}

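For orientation (illustrative, not part of the patch): remMask<N>(m) builds an N-lane mask with the low m bits set; the unroll helpers hand it to masked packet loads/stores when fewer than N rows or columns remain. Two concrete values, following directly from the definition above:

// remMask<16>(3) == 0xFFFF >> (16 - 3) == 0x0007  -> lanes 0..2 enabled
// remMask<8>(5)  == 0xFF   >> (8 - 5)  == 0x001F  -> lanes 0..4 enabled
static_assert((0xFFFF >> (16 - 3)) == 0x0007, "3 of 16 lanes");
static_assert((0xFF >> (8 - 5)) == 0x1F, "5 of 8 lanes");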
@@ -105,10 +108,14 @@ EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8>& kernel) {
  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);

  kernel.packet[0] = T0;
  kernel.packet[1] = T1;
  kernel.packet[2] = T2;
  kernel.packet[3] = T3;
  kernel.packet[4] = T4;
  kernel.packet[5] = T5;
  kernel.packet[6] = T6;
  kernel.packet[7] = T7;
}

template <>
@@ -126,7 +133,6 @@ public:
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
   * Auxillary Functions for:
   *  - storeC
@@ -143,9 +149,8 @@ public:
   *
   **/
  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

@@ -159,8 +164,7 @@ public:
                        remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(C_arr + LDC * startN,
                        padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                             preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
      }
@@ -176,26 +180,25 @@ public:
      EIGEN_IF_CONSTEXPR(remM) {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
            remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
      }
    }
    aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }

  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
@@ -203,9 +206,9 @@ public:
  }

  template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t remM_ = 0) {
    aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }

@@ -235,8 +238,7 @@ public:
   * avx registers are being transposed.
   */
  template <int64_t unrollN, int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
    // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
    constexpr int64_t zmmStride = unrollN / PacketSize;
@@ -298,27 +300,25 @@ public:
   *  for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remM) {
      ymm.packet[packetIndexOffset + startN] =
          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
    }
    else ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);

    aux_loadB<endN, counter - 1, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
@@ -332,16 +332,13 @@ public:
   *  for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(remK || remM) {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
                      remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
    }
    else {
@@ -352,10 +349,8 @@ public:
  }

  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
@@ -369,23 +364,19 @@ public:
   *  for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
   **/
  template <int64_t endN, int64_t counter, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;
    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false>(&B_temp[startN], LDB_, ymm);
    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
@@ -394,7 +385,6 @@ public:
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  /**
   * aux_storeBBlock
   *
@@ -402,31 +392,26 @@ public:
   *  for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
   **/
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(toTemp) {
      transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
    }
    else {
      transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
                                                                                          ymm, remM_);
    }
    aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
@@ -435,51 +420,43 @@ public:
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  /********************************************************
   * Wrappers for aux_XXXX to hide counter parameter
   ********************************************************/

  template <int64_t endN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
                                        PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                        int64_t remM_ = 0) {
    aux_loadB<endN, endN, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                         int64_t rem_ = 0) {
    aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }

  template <int64_t unrollN, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                             PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                             int64_t remM_ = 0) {
    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM>(&B_arr[0], LDB, ymm, remM_); }
    else {
      aux_loadBBlock<unrollN, unrollN, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }

  template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                              int64_t remM_ = 0) {
    aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

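For orientation (illustrative, not part of the patch): every aux_XXXX/wrapper pair in these classes follows the same idiom — a recursive template whose counter parameter acts as a compile-time loop index, an enable_if base case that ends the recursion, and a thin wrapper that hides the counter. A stripped-down sketch of the pattern, with hypothetical names:

#include <cstdint>
#include <type_traits>

// Recursive case: handle iteration (endN - counter), then recurse.
template <int64_t endN, int64_t counter>
static inline std::enable_if_t<(counter > 0)> aux_scale(double *data) {
  constexpr int64_t startN = endN - counter;  // current "loop index"
  data[startN] *= 2.0;                        // loop body, fully unrolled
  aux_scale<endN, counter - 1>(data);
}

// Base case: counter exhausted, nothing left to do.
template <int64_t endN, int64_t counter>
static inline std::enable_if_t<(counter <= 0)> aux_scale(double *data) {
  (void)data;
}

// Wrapper hides the counter, mirroring loadB/storeB/storeC above.
template <int64_t endN>
static inline void scale(double *data) {
  aux_scale<endN, endN>(data);
}

Calling scale<4>(p) expands at compile time into four scalar multiplies with no runtime loop, which is why the kernels express their fixed trip-count loops this way.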
  template <int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
    // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
    PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
@@ -503,9 +480,9 @@ public:
  }

  template <int64_t unrollN, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                                PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                                int64_t remM_ = 0) {
    constexpr int64_t U3 = PacketSize * 3;
    constexpr int64_t U2 = PacketSize * 2;
    constexpr int64_t U1 = PacketSize * 1;
@@ -526,11 +503,13 @@ public:
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
        transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                             ymm, remM_);
        transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                                 ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U2) {
@@ -543,20 +522,18 @@ public:
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
        transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
                                                                         &B_temp[maxUBlock], LDB_, ymm, remM_);
        transB::template transposeLxL<0>(ymm);
        transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
                                                                             &B_temp[maxUBlock], LDB_, ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U1) {
      // load LxU1 B col major, transpose LxU1 row major
      transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
      transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
@@ -601,9 +578,7 @@ public:
template <typename Scalar>
class trsm {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
@@ -622,9 +597,8 @@ public:
   *  for(startK = 0; startK < endK; startK++)
   **/
  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;
@@ -642,9 +616,8 @@ public:
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
@@ -659,8 +632,8 @@ public:
   *  for(startK = 0; startK < endK; startK++)
   **/
  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    constexpr int64_t counterReverse = endM * endK - counter;
    constexpr int64_t startM = counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;
@@ -678,9 +651,8 @@ public:
  }

  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
@@ -696,8 +668,8 @@ public:
   *  for(startK = 0; startK < endK; startK++)
   **/
  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endK - counter;
    constexpr int64_t startK = counterReverse;

@@ -707,8 +679,8 @@ public:
  }

  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
@@ -720,10 +692,11 @@ public:
   * for(startM = initM; startM < endM; startM++)
   *  for(startK = 0; startK < endK; startK++)
   **/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = (endM - initM) * endK - counter;
    constexpr int64_t startM = initM + counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;
@@ -732,8 +705,7 @@ public:
    constexpr int64_t packetIndex = startM * endK + startK;
    EIGEN_IF_CONSTEXPR(currentM > 0) {
      RHSInPacket.packet[packetIndex] =
          pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
                 RHSInPacket.packet[packetIndex]);
    }

@@ -744,24 +716,25 @@ public:
        // This will be used in divRHSByDiag
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
        else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
      }
      else {
        // Broadcast next off diagonal element of A
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
        else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
      }
    }

    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
@@ -775,9 +748,9 @@ public:
   *  for(startM = 0; startM < endM; startM++)
   **/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    constexpr int64_t counterReverse = endM - counter;
    constexpr int64_t startM = counterReverse;

@@ -794,21 +767,21 @@ public:
    // After division, the rhs corresponding to subsequent rows of A can be partially updated
    // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
    // to be used in the next iteration.
    trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
                                                                                                AInPacket);

    // Handle division for the RHS corresponding to the final row of A.
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
    trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);

    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
  }

  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
@@ -824,8 +797,8 @@ public:
   * Masked loads are used for cases where endK is not a multiple of PacketSize
   */
  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
                                          PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

@@ -834,8 +807,8 @@ public:
   * Masked loads are used for cases where endK is not a multiple of PacketSize
   */
  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
                                           PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }

@@ -843,8 +816,8 @@ public:
   * Only used if Triangular matrix has non-unit diagonal values
   */
  template <int64_t currM, int64_t endK>
  static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                               PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
  }

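For orientation (illustrative, not part of the patch): the "division" here is carried out as a multiplication by the reciprocal of the diagonal entry, which updateRHS broadcasts into AInPacket one step earlier. In scalar terms, roughly:

#include <cstdint>

// Scalar sketch of divRHSByDiag: scale one row of right-hand sides by 1 / A(i, i).
void div_rhs_by_diag_scalar(double *rhs_row, int64_t numCols, double diag) {
  const double recip = 1.0 / diag;  // reciprocal computed once, reused across columns
  for (int64_t k = 0; k < numCols; ++k) rhs_row[k] *= recip;
}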
@@ -852,9 +825,11 @@ public:
   * Update right-hand sides (stored in avx registers)
   * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket.
   **/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
            int64_t currentM>
  static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }
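For orientation (illustrative, not part of the patch): updateRHS is the vectorized substitution step of a triangular solve — once the solution row selected by currentM is available, every remaining right-hand-side row is reduced by the corresponding entry of that column of A (the pnmadd above). A scalar sketch of the same update for a forward solve with row-major A:

#include <cstdint>

void update_rhs_scalar(const double *A, int64_t LDA, double *rhs, int64_t LDR,
                       int64_t startM, int64_t endM, int64_t currentM, int64_t numCols) {
  // rhs holds endM rows of numCols right-hand sides, row-major with leading dimension LDR.
  for (int64_t i = startM; i < endM; ++i)
    for (int64_t k = 0; k < numCols; ++k)
      rhs[i * LDR + k] -= A[i * LDA + currentM] * rhs[currentM * LDR + k];
}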
@@ -866,11 +841,11 @@ public:
   * isUnitDiag: true => triangular matrix has unit diagonal.
   */
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
  static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    static_assert(numK >= 1 && numK <= 3, "numK out of range");
    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
  }
};

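For orientation (illustrative, not part of the patch): the micro-kernel solves an endM-row block of right-hand sides held entirely in registers, with numK packets per row. A hypothetical invocation for an 8 x 48 float block (8 rows, 3 packets of 16 floats), assuming a forward solve on a row-major, non-unit-diagonal triangle and that the snippet is compiled inside the same internal namespace as the class above; the names A_tri and B_panel are made up for the sketch:

void solve_8x48_block(float *A_tri, int64_t LDA, float *B_panel, int64_t LDB) {
  PacketBlock<vecFullFloat, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
  PacketBlock<vecFullFloat, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
  trsm<float>::loadRHS<true, 8, 3>(B_panel, LDB, RHSInPacket);   // pull the 8x48 block into registers
  trsm<float>::triSolveMicroKernel<true, true, false, 8, 3>(A_tri, LDA, RHSInPacket, AInPacket);
  trsm<float>::storeRHS<true, 8, 3>(B_panel, LDB, RHSInPacket);  // write the solved block back
}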
@@ -902,8 +877,8 @@ public:
 * for(startN = 0; startN < endN; startN++)
 **/
template <int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  constexpr int64_t counterReverse = endM * endN - counter;
  constexpr int64_t startM = counterReverse / (endN);
  constexpr int64_t startN = counterReverse % endN;

@@ -913,9 +888,8 @@ public:
}

template <int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
    PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  EIGEN_UNUSED_VARIABLE(zmm);
}

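Both aux_setzero overloads (and the aux_updateC/aux_storeC pairs that follow) rely on the same compile-time unrolling idiom: the (counter > 0) overload peels exactly one (startM, startN) step and recurses, while the (counter <= 0) overload is selected by SFINAE and terminates the recursion. A self-contained sketch of the pattern, with illustrative names rather than Eigen's:

#include <cstdint>
#include <cstdio>
#include <type_traits>

template <int64_t endM, int64_t endN, int64_t counter>
std::enable_if_t<(counter > 0)> visit() {
  constexpr int64_t counterReverse = endM * endN - counter;
  constexpr int64_t startM = counterReverse / endN;  // outer "loop" index recovered at compile time
  constexpr int64_t startN = counterReverse % endN;  // inner "loop" index
  std::printf("step (%lld, %lld)\n", (long long)startM, (long long)startN);
  visit<endM, endN, counter - 1>();                  // peel the next iteration
}

template <int64_t endM, int64_t endN, int64_t counter>
std::enable_if_t<(counter <= 0)> visit() {}          // terminator: nothing left to unroll

int main() {
  visit<2, 3, 2 * 3>();  // behaves like a fully unrolled 2x3 double loop
  return 0;
}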
@@ -927,8 +901,8 @@ public:
 * for(startN = 0; startN < endN; startN++)
 **/
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  constexpr int64_t counterReverse = endM * endN - counter;
  constexpr int64_t startM = counterReverse / (endN);

@@ -937,18 +911,15 @@ public:
  EIGEN_IF_CONSTEXPR(rem)
  zmm.packet[startN * endM + startM] =
      padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
           zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
  else zmm.packet[startN * endM + startM] =
      padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
  aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
}

template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(C_arr);
  EIGEN_UNUSED_VARIABLE(LDC);
  EIGEN_UNUSED_VARIABLE(zmm);

@@ -963,24 +934,23 @@ public:
 * for(startN = 0; startN < endN; startN++)
 **/
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  constexpr int64_t counterReverse = endM * endN - counter;
  constexpr int64_t startM = counterReverse / (endN);
  constexpr int64_t startN = counterReverse % endN;

  EIGEN_IF_CONSTEXPR(rem)
  pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
                  remMask<PacketSize>(rem_));
  else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
  aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
}

template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
    Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(C_arr);
  EIGEN_UNUSED_VARIABLE(LDC);
  EIGEN_UNUSED_VARIABLE(zmm);

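The rem/rem_ pair threaded through aux_updateC and aux_storeC handles a partial final packet: when fewer than PacketSize elements remain, loads and stores go through remMask<PacketSize>(rem_) instead of a scalar cleanup loop. A standalone illustration of the same idea with plain AVX-512 intrinsics (not Eigen's packet API; assumes AVX-512F is enabled), summing one float array into another:

#include <immintrin.h>
#include <cstdint>

void addArrays(float *dst, const float *src, int64_t n) {
  int64_t i = 0;
  for (; i + 16 <= n; i += 16) {  // full 16-float packets
    _mm512_storeu_ps(dst + i, _mm512_add_ps(_mm512_loadu_ps(dst + i), _mm512_loadu_ps(src + i)));
  }
  if (i < n) {                    // remainder of 1..15 elements, handled with a k-mask
    __mmask16 m = (__mmask16)((1u << (n - i)) - 1u);  // plays the role of remMask<PacketSize>(rem_)
    __m512 a = _mm512_maskz_loadu_ps(m, dst + i);
    __m512 b = _mm512_maskz_loadu_ps(m, src + i);
    _mm512_mask_storeu_ps(dst + i, m, _mm512_add_ps(a, b));
  }
}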
@@ -994,8 +964,8 @@ public:
 * for(startL = 0; startL < endL; startL++)
 **/
template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
    Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  constexpr int64_t counterReverse = endL - counter;
  constexpr int64_t startL = counterReverse;

@@ -1003,18 +973,15 @@ public:
  EIGEN_IF_CONSTEXPR(rem)
  zmm.packet[unrollM * unrollN + startL] =
      ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
  else zmm.packet[unrollM * unrollN + startL] =
      ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);

  aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
}

template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
    Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(B_t);
  EIGEN_UNUSED_VARIABLE(LDB);
  EIGEN_UNUSED_VARIABLE(zmm);

@@ -1028,8 +995,8 @@ public:
 * for(startB = 0; startB < endB; startB++)
 **/
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
    Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  constexpr int64_t counterReverse = endB - counter;
  constexpr int64_t startB = counterReverse;

@@ -1039,9 +1006,8 @@ public:
}

template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
    Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  EIGEN_UNUSED_VARIABLE(A_t);
  EIGEN_UNUSED_VARIABLE(LDA);
  EIGEN_UNUSED_VARIABLE(zmm);

@@ -1054,9 +1020,10 @@ public:
 * 1-D unroll
 * for(startM = 0; startM < endM; startM++)
 **/
template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
          int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
    Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  if ((numLoad / endM + currK < unrollK)) {
    constexpr int64_t counterReverse = endM - counter;

@@ -1075,12 +1042,10 @@ public:
    }
  }
}

template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
          int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
    Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(B_t);
  EIGEN_UNUSED_VARIABLE(LDB);
  EIGEN_UNUSED_VARIABLE(zmm);

@@ -1095,11 +1060,11 @@ public:
 * for(startN = 0; startN < endN; startN++)
 * for(startK = 0; startK < endK; startK++)
 **/
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
          int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
    Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
    int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  constexpr int64_t counterReverse = endM * endN * endK - counter;
  constexpr int startK = counterReverse / (endM * endN);

@@ -1107,10 +1072,8 @@ public:
  constexpr int startM = counterReverse % endM;

  EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
    gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
    gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
  }

  {

@@ -1127,9 +1090,8 @@ public:
    }
    // Bcast
    EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
      zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
          (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
    }
  }

@@ -1138,15 +1100,13 @@ public:
    gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
  }
  aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
}

template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
          int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
    Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
    int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(B_t);
  EIGEN_UNUSED_VARIABLE(A_t);
  EIGEN_UNUSED_VARIABLE(LDB);

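Taken together, startLoadB/startBCastA, the multiply-accumulate step, and loadB implement a classic register-blocked GEMM update: an endM x endN tile of accumulators stays in registers for the whole K loop while packets of B are loaded and single elements of A are broadcast into them. A scalar sketch of the same dataflow (illustrative only; in the real kernel the B operand is a full zmm packet rather than a scalar):

#include <cstdint>

// Reference register-blocking pattern behind microKernel: accumulate an endM x endN tile of C,
// then add it back, which is what updateC/storeC do.
template <int endM, int endN>
void microKernelRef(const double *A, const double *B, double *C, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC) {
  double acc[endM][endN] = {};  // accumulator tile; maps to zmm.packet[0 .. endM*endN-1]
  for (int64_t k = 0; k < K; ++k) {
    for (int n = 0; n < endN; ++n) {
      const double b = B[k * LDB + n];    // corresponds to a packet of B loaded by (start)loadB
      for (int m = 0; m < endM; ++m) {
        acc[m][n] += A[m * LDA + k] * b;  // corresponds to a pload1 broadcast of A multiplied in
      }
    }
  }
  for (int m = 0; m < endM; ++m)
    for (int n = 0; n < endN; ++n) C[m * LDC + n] += acc[m][n];  // the updateC step
}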
@@ -1160,8 +1120,7 @@ public:
 ********************************************************/

template <int64_t endM, int64_t endN>
static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  aux_setzero<endM, endN, endM * endN>(zmm);
}

@@ -1169,15 +1128,17 @@ public:
 * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
 */
template <int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
                                        PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                        int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
}

template <int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                       PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                       int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
}

@@ -1186,8 +1147,9 @@ public:
 * Use numLoad registers for loading B at start of microKernel
 */
template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
                                           PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                           int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
}

@@ -1196,8 +1158,8 @@ public:
 * Use numBCast registers for broadcasting A at start of microKernel
 */
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
                                            PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
  aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
}

@@ -1205,9 +1167,9 @@ public:
 * Loads next set of B into vector registers between each K unroll.
 */
template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
                                      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                      int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
}

@@ -1235,18 +1197,16 @@ public:
 * From testing, there are no register spills with clang. There are register spills with GNU, which
 * causes a performance hit.
 */
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
          bool rem = false>
static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
                                            PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                            int64_t rem_ = 0) {
  EIGEN_UNUSED_VARIABLE(rem_);
  aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
                                                                                             rem_);
}

};
} // namespace unrolls


#endif // EIGEN_UNROLLS_IMPL_H

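The register-spill remark above comes down to a simple budget: the accumulator tile occupies endM * endN zmm registers (zmm.packet[0 .. endM*endN-1]), with numLoad staging packets of B and numBCast broadcast values of A placed after it, and AVX-512 exposes 32 zmm registers in total. A hedged back-of-the-envelope check (the specific unroll sizes below are only an example, not taken from the kernels):

// If the tile plus staging registers exceeds the 32-entry zmm file, the compiler must spill.
constexpr bool fitsInZmm(int endM, int endN, int numLoad, int numBCast) {
  return endM * endN + numLoad + numBCast <= 32;
}
static_assert(fitsInZmm(3, 8, 4, 4), "example: 3x8 accumulator tile + 4 B loads + 4 A broadcasts");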
@@ -171,7 +171,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
   std::ptrdiff_t l1, l2, l3;
   manage_caching_sizes(GetAction, &l1, &l2, &l3);

-#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSML_CUTOFFS)
+#if (EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
   EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
                        (std::is_same<Scalar,float>::value ||
                         std::is_same<Scalar,double>::value)) ) {
@@ -246,7 +246,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
       // tr solve
       {
         Index i = IsLower ? k2+k1 : k2-k1;
-#if defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
+#if EIGEN_USE_AVX512_TRSM_L_KERNELS
         EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
                              (std::is_same<Scalar,float>::value ||
                               std::is_same<Scalar,double>::value)) ) {
@@ -318,7 +318,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
 {
   Index rows = otherSize;

-#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
+#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
   EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
                        (std::is_same<Scalar,float>::value ||
                         std::is_same<Scalar,double>::value)) ) {
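With the guards now testing the macros' values instead of their mere definition, a translation unit can opt out of (or selectively keep) the optimized TRSM paths by pre-defining the macros to 0 before Eigen is included. A hedged usage sketch:

// Illustrative opt-out: define the controlling macros to 0 before including Eigen,
// or pass them on the compile line, e.g. -DEIGEN_USE_AVX512_TRSM_KERNELS=0.
#define EIGEN_USE_AVX512_TRSM_KERNELS 0                 // skip all optimized AVX512 TRSM kernels
// #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0  // or keep the kernels but disable only
                                                        //   the right-side no-copy dispatch
#include <Eigen/Dense>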