mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-16 10:01:49 +08:00
AVX512 TRSM Kernels respect EIGEN_NO_MALLOC
This commit is contained in:
parent
9960a30422
commit
28812d2ebb
@ -14,6 +14,13 @@
|
||||
|
||||
#define EIGEN_USE_AVX512_TRSM_KERNELS // Comment out to prevent using optimized trsm kernels.
|
||||
|
||||
#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
|
||||
#define EIGEN_USE_AVX512_TRSM_R_KERNELS
|
||||
#if !defined(EIGEN_NO_MALLOC) // Separate MACRO since these kernels require malloc
|
||||
#define EIGEN_USE_AVX512_TRSM_L_KERNELS
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(EIGEN_HAS_CXX17_IFCONSTEXPR)
|
||||
#define EIGEN_IF_CONSTEXPR(X) if constexpr (X)
|
||||
#else
|
||||
@ -61,6 +68,14 @@ typedef Packet4d vecHalfDouble;
|
||||
*
|
||||
*/
|
||||
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS // Comment out to disable no-copy dispatch
|
||||
|
||||
#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
|
||||
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
|
||||
#if !defined(EIGEN_NO_MALLOC)
|
||||
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename Scalar>
|
||||
int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap){
|
||||
const int64_t U3 = 3*packet_traits<Scalar>::size;
|
||||
@ -882,11 +897,7 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
|
||||
* The updated row-major copy of B is reused in the GEMM updates.
|
||||
*/
|
||||
sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM;
|
||||
#if EIGEN_COMP_MSVC
|
||||
B_temp = (Scalar*) _aligned_malloc(sizeof(Scalar)*sizeBTemp,4096);
|
||||
#else
|
||||
B_temp = (Scalar*) aligned_alloc(4096,sizeof(Scalar)*sizeBTemp);
|
||||
#endif
|
||||
B_temp = (Scalar*) handmade_aligned_malloc(sizeof(Scalar)*sizeBTemp,4096);
|
||||
}
|
||||
for(int64_t k = 0; k < numRHS; k += kB) {
|
||||
int64_t bK = numRHS - k > kB ? kB : numRHS - k;
|
||||
@ -1030,55 +1041,29 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
|
||||
}
|
||||
}
|
||||
}
|
||||
#if EIGEN_COMP_MSVC
|
||||
EIGEN_IF_CONSTEXPR(!isBRowMajor) _aligned_free(B_temp);
|
||||
#else
|
||||
EIGEN_IF_CONSTEXPR(!isBRowMajor) free(B_temp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Scalar, bool isARowMajor = true, bool isCRowMajor = true>
|
||||
void gemmKer(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr,
|
||||
int64_t M, int64_t N, int64_t K,
|
||||
int64_t LDA, int64_t LDB, int64_t LDC) {
|
||||
gemmKernel<Scalar, isARowMajor, isCRowMajor, true, true>(B_arr, A_arr, C_arr, N, M, K, LDB, LDA, LDC);
|
||||
EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
|
||||
}
|
||||
|
||||
|
||||
// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
|
||||
#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
|
||||
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
|
||||
struct trsm_kernels;
|
||||
struct trsmKernelR;
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
struct trsm_kernels<float, Index, Mode, false, TriStorageOrder, 1>{
|
||||
static void trsmKernelL(Index size, Index otherSize, const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride);
|
||||
static void trsmKernelR(Index size, Index otherSize, const float* _tri, Index triStride,
|
||||
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1>{
|
||||
static void kernel(Index size, Index otherSize, const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
struct trsm_kernels<double, Index, Mode, false, TriStorageOrder, 1>{
|
||||
static void trsmKernelL(Index size, Index otherSize, const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride);
|
||||
static void trsmKernelR(Index size, Index otherSize, const double* _tri, Index triStride,
|
||||
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1>{
|
||||
static void kernel(Index size, Index otherSize, const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsm_kernels<float, Index, Mode, false, TriStorageOrder, 1>::trsmKernelL(
|
||||
Index size, Index otherSize,
|
||||
const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(otherIncr);
|
||||
triSolve<float, TriStorageOrder==RowMajor, false, (Mode&Lower)==Lower, (Mode & UnitDiag)!=0>(
|
||||
const_cast<float*>(_tri), _other, size, otherSize, triStride, otherStride);
|
||||
}
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsm_kernels<float, Index, Mode, false, TriStorageOrder, 1>::trsmKernelR(
|
||||
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride)
|
||||
@ -1089,18 +1074,7 @@ EIGEN_DONT_INLINE void trsm_kernels<float, Index, Mode, false, TriStorageOrder,
|
||||
}
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsm_kernels<double, Index, Mode, false, TriStorageOrder, 1>::trsmKernelL(
|
||||
Index size, Index otherSize,
|
||||
const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(otherIncr);
|
||||
triSolve<double, TriStorageOrder==RowMajor, false, (Mode&Lower)==Lower, (Mode & UnitDiag)!=0>(
|
||||
const_cast<double*>(_tri), _other, size, otherSize, triStride, otherStride);
|
||||
}
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsm_kernels<double, Index, Mode, false, TriStorageOrder, 1>::trsmKernelR(
|
||||
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride)
|
||||
@ -1109,6 +1083,46 @@ EIGEN_DONT_INLINE void trsm_kernels<double, Index, Mode, false, TriStorageOrder,
|
||||
triSolve<double, TriStorageOrder!=RowMajor, true, (Mode&Lower)!=Lower, (Mode & UnitDiag)!=0>(
|
||||
const_cast<double*>(_tri), _other, size, otherSize, triStride, otherStride);
|
||||
}
|
||||
|
||||
// These trsm kernels require temporary memory allocation, so disable them if malloc is not allowed.
|
||||
#if defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
|
||||
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
|
||||
struct trsmKernelL;
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1>{
|
||||
static void kernel(Index size, Index otherSize, const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1>{
|
||||
static void kernel(Index size, Index otherSize, const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const float* _tri, Index triStride,
|
||||
float* _other, Index otherIncr, Index otherStride)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(otherIncr);
|
||||
triSolve<float, TriStorageOrder==RowMajor, false, (Mode&Lower)==Lower, (Mode & UnitDiag)!=0>(
|
||||
const_cast<float*>(_tri), _other, size, otherSize, triStride, otherStride);
|
||||
}
|
||||
|
||||
template <typename Index, int Mode, int TriStorageOrder>
|
||||
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const double* _tri, Index triStride,
|
||||
double* _other, Index otherIncr, Index otherStride)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(otherIncr);
|
||||
triSolve<double, TriStorageOrder==RowMajor, false, (Mode&Lower)==Lower, (Mode & UnitDiag)!=0>(
|
||||
const_cast<double*>(_tri), _other, size, otherSize, triStride, otherStride);
|
||||
}
|
||||
#endif //EIGEN_USE_AVX512_TRSM_L_KERNELS
|
||||
#endif //EIGEN_USE_AVX512_TRSM_KERNELS
|
||||
}
|
||||
}
|
||||
|
@ -18,24 +18,27 @@ namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
|
||||
struct trsm_kernels {
|
||||
struct trsmKernelL {
|
||||
// Generic Implementation of triangular solve for triangular matrix on left and multiple rhs.
|
||||
// Handles non-packed matrices.
|
||||
static void trsmKernelL(
|
||||
Index size, Index otherSize,
|
||||
const Scalar* _tri, Index triStride,
|
||||
Scalar* _other, Index otherIncr, Index otherStride);
|
||||
|
||||
// Generic Implementation of triangular solve for triangular matrix on right and multiple lhs.
|
||||
// Handles non-packed matrices.
|
||||
static void trsmKernelR(
|
||||
static void kernel(
|
||||
Index size, Index otherSize,
|
||||
const Scalar* _tri, Index triStride,
|
||||
Scalar* _other, Index otherIncr, Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
|
||||
EIGEN_STRONG_INLINE void trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelL(
|
||||
struct trsmKernelR {
|
||||
// Generic Implementation of triangular solve for triangular matrix on right and multiple lhs.
|
||||
// Handles non-packed matrices.
|
||||
static void kernel(
|
||||
Index size, Index otherSize,
|
||||
const Scalar* _tri, Index triStride,
|
||||
Scalar* _other, Index otherIncr, Index otherStride);
|
||||
};
|
||||
|
||||
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
|
||||
EIGEN_STRONG_INLINE void trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const Scalar* _tri, Index triStride,
|
||||
Scalar* _other, Index otherIncr, Index otherStride)
|
||||
@ -86,7 +89,7 @@ EIGEN_STRONG_INLINE void trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorage
|
||||
|
||||
|
||||
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
|
||||
EIGEN_STRONG_INLINE void trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelR(
|
||||
EIGEN_STRONG_INLINE void trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::kernel(
|
||||
Index size, Index otherSize,
|
||||
const Scalar* _tri, Index triStride,
|
||||
Scalar* _other, Index otherIncr, Index otherStride)
|
||||
@ -168,7 +171,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
|
||||
std::ptrdiff_t l1, l2, l3;
|
||||
manage_caching_sizes(GetAction, &l1, &l2, &l3);
|
||||
|
||||
#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
|
||||
#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSML_CUTOFFS)
|
||||
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
|
||||
(std::is_same<Scalar,float>::value ||
|
||||
std::is_same<Scalar,double>::value)) ) {
|
||||
@ -177,7 +180,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
|
||||
// TODO: Investigate better heuristics for cutoffs.
|
||||
double L2Cap = 0.5; // 50% of L2 size
|
||||
if (size < avx512_trsm_cutoff<Scalar>(l2, cols, L2Cap)) {
|
||||
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, 1>::trsmKernelL(
|
||||
trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, 1>::kernel(
|
||||
size, cols, _tri, triStride, _other, 1, otherStride);
|
||||
return;
|
||||
}
|
||||
@ -243,14 +246,14 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
|
||||
// tr solve
|
||||
{
|
||||
Index i = IsLower ? k2+k1 : k2-k1;
|
||||
#if defined(EIGEN_USE_AVX512_TRSM_KERNELS)
|
||||
#if defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
|
||||
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
|
||||
(std::is_same<Scalar,float>::value ||
|
||||
std::is_same<Scalar,double>::value)) ) {
|
||||
i = IsLower ? k2 + k1: k2 - k1 - actualPanelWidth;
|
||||
}
|
||||
#endif
|
||||
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::trsmKernelL(
|
||||
trsmKernelL<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::kernel(
|
||||
actualPanelWidth, actual_cols,
|
||||
_tri + i + (i)*triStride, triStride,
|
||||
_other + i*OtherInnerStride + j2*otherStride, otherIncr, otherStride);
|
||||
@ -315,7 +318,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
|
||||
{
|
||||
Index rows = otherSize;
|
||||
|
||||
#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
|
||||
#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
|
||||
EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 &&
|
||||
(std::is_same<Scalar,float>::value ||
|
||||
std::is_same<Scalar,double>::value)) ) {
|
||||
@ -324,8 +327,8 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
|
||||
manage_caching_sizes(GetAction, &l1, &l2, &l3);
|
||||
double L2Cap = 0.5; // 50% of L2 size
|
||||
if (size < avx512_trsm_cutoff<Scalar>(l2, rows, L2Cap)) {
|
||||
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
|
||||
trsmKernelR(size, rows, _tri, triStride, _other, 1, otherStride);
|
||||
trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
|
||||
kernel(size, rows, _tri, triStride, _other, 1, otherStride);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -420,8 +423,8 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
|
||||
|
||||
{
|
||||
// unblocked triangular solve
|
||||
trsm_kernels<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
|
||||
trsmKernelR(actualPanelWidth, actual_mc,
|
||||
trsmKernelR<Scalar, Index, Mode, Conjugate, TriStorageOrder, OtherInnerStride>::
|
||||
kernel(actualPanelWidth, actual_mc,
|
||||
_tri + absolute_j2 + absolute_j2*triStride, triStride,
|
||||
_other + i2*OtherInnerStride + absolute_j2*otherStride, otherIncr, otherStride);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user