From 28812d2ebbb82ceb3bc24684fd341f73dcb15041 Mon Sep 17 00:00:00 2001 From: "Shi, Brian" Date: Mon, 6 Jun 2022 17:03:10 -0700 Subject: [PATCH] AVX512 TRSM Kernels respect EIGEN_NO_MALLOC --- Eigen/src/Core/arch/AVX512/TrsmKernel.h | 114 ++++++++++-------- .../Core/products/TriangularSolverMatrix.h | 43 ++++--- 2 files changed, 87 insertions(+), 70 deletions(-) diff --git a/Eigen/src/Core/arch/AVX512/TrsmKernel.h b/Eigen/src/Core/arch/AVX512/TrsmKernel.h index a8b999c31..b39524958 100644 --- a/Eigen/src/Core/arch/AVX512/TrsmKernel.h +++ b/Eigen/src/Core/arch/AVX512/TrsmKernel.h @@ -14,6 +14,13 @@ #define EIGEN_USE_AVX512_TRSM_KERNELS // Comment out to prevent using optimized trsm kernels. +#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) +#define EIGEN_USE_AVX512_TRSM_R_KERNELS +#if !defined(EIGEN_NO_MALLOC) // Separate MACRO since these kernels require malloc +#define EIGEN_USE_AVX512_TRSM_L_KERNELS +#endif +#endif + #if defined(EIGEN_HAS_CXX17_IFCONSTEXPR) #define EIGEN_IF_CONSTEXPR(X) if constexpr (X) #else @@ -61,6 +68,14 @@ typedef Packet4d vecHalfDouble; * */ #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS // Comment out to disable no-copy dispatch + +#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS) +#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS +#if !defined(EIGEN_NO_MALLOC) +#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS +#endif +#endif + template int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap){ const int64_t U3 = 3*packet_traits::size; @@ -882,11 +897,7 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L * The updated row-major copy of B is reused in the GEMM updates. */ sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM; -#if EIGEN_COMP_MSVC - B_temp = (Scalar*) _aligned_malloc(sizeof(Scalar)*sizeBTemp,4096); -#else - B_temp = (Scalar*) aligned_alloc(4096,sizeof(Scalar)*sizeBTemp); -#endif + B_temp = (Scalar*) handmade_aligned_malloc(sizeof(Scalar)*sizeBTemp,4096); } for(int64_t k = 0; k < numRHS; k += kB) { int64_t bK = numRHS - k > kB ? kB : numRHS - k; @@ -1030,55 +1041,29 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L } } } -#if EIGEN_COMP_MSVC - EIGEN_IF_CONSTEXPR(!isBRowMajor) _aligned_free(B_temp); -#else - EIGEN_IF_CONSTEXPR(!isBRowMajor) free(B_temp); -#endif -} -template -void gemmKer(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, - int64_t M, int64_t N, int64_t K, - int64_t LDA, int64_t LDB, int64_t LDC) { - gemmKernel(B_arr, A_arr, C_arr, N, M, K, LDB, LDA, LDC); + EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp); } - // Template specializations of trsmKernelL/R for float/double and inner strides of 1. #if defined(EIGEN_USE_AVX512_TRSM_KERNELS) template -struct trsm_kernels; +struct trsmKernelR; template -struct trsm_kernels{ - static void trsmKernelL(Index size, Index otherSize, const float* _tri, Index triStride, - float* _other, Index otherIncr, Index otherStride); - static void trsmKernelR(Index size, Index otherSize, const float* _tri, Index triStride, +struct trsmKernelR{ + static void kernel(Index size, Index otherSize, const float* _tri, Index triStride, float* _other, Index otherIncr, Index otherStride); }; template -struct trsm_kernels{ - static void trsmKernelL(Index size, Index otherSize, const double* _tri, Index triStride, - double* _other, Index otherIncr, Index otherStride); - static void trsmKernelR(Index size, Index otherSize, const double* _tri, Index triStride, +struct trsmKernelR{ + static void kernel(Index size, Index otherSize, const double* _tri, Index triStride, double* _other, Index otherIncr, Index otherStride); }; template -EIGEN_DONT_INLINE void trsm_kernels::trsmKernelL( - Index size, Index otherSize, - const float* _tri, Index triStride, - float* _other, Index otherIncr, Index otherStride) -{ - EIGEN_UNUSED_VARIABLE(otherIncr); - triSolve( - const_cast(_tri), _other, size, otherSize, triStride, otherStride); -} - -template -EIGEN_DONT_INLINE void trsm_kernels::trsmKernelR( +EIGEN_DONT_INLINE void trsmKernelR::kernel( Index size, Index otherSize, const float* _tri, Index triStride, float* _other, Index otherIncr, Index otherStride) @@ -1089,18 +1074,7 @@ EIGEN_DONT_INLINE void trsm_kernels -EIGEN_DONT_INLINE void trsm_kernels::trsmKernelL( - Index size, Index otherSize, - const double* _tri, Index triStride, - double* _other, Index otherIncr, Index otherStride) -{ - EIGEN_UNUSED_VARIABLE(otherIncr); - triSolve( - const_cast(_tri), _other, size, otherSize, triStride, otherStride); -} - -template -EIGEN_DONT_INLINE void trsm_kernels::trsmKernelR( +EIGEN_DONT_INLINE void trsmKernelR::kernel( Index size, Index otherSize, const double* _tri, Index triStride, double* _other, Index otherIncr, Index otherStride) @@ -1109,6 +1083,46 @@ EIGEN_DONT_INLINE void trsm_kernels( const_cast(_tri), _other, size, otherSize, triStride, otherStride); } + +// These trsm kernels require temporary memory allocation, so disable them if malloc is not allowed. +#if defined(EIGEN_USE_AVX512_TRSM_L_KERNELS) +template +struct trsmKernelL; + +template +struct trsmKernelL{ + static void kernel(Index size, Index otherSize, const float* _tri, Index triStride, + float* _other, Index otherIncr, Index otherStride); +}; + +template +struct trsmKernelL{ + static void kernel(Index size, Index otherSize, const double* _tri, Index triStride, + double* _other, Index otherIncr, Index otherStride); +}; + +template +EIGEN_DONT_INLINE void trsmKernelL::kernel( + Index size, Index otherSize, + const float* _tri, Index triStride, + float* _other, Index otherIncr, Index otherStride) +{ + EIGEN_UNUSED_VARIABLE(otherIncr); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); +} + +template +EIGEN_DONT_INLINE void trsmKernelL::kernel( + Index size, Index otherSize, + const double* _tri, Index triStride, + double* _other, Index otherIncr, Index otherStride) +{ + EIGEN_UNUSED_VARIABLE(otherIncr); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); +} +#endif //EIGEN_USE_AVX512_TRSM_L_KERNELS #endif //EIGEN_USE_AVX512_TRSM_KERNELS } } diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index def6a28f2..8b9be8b60 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -18,24 +18,27 @@ namespace Eigen { namespace internal { template -struct trsm_kernels { +struct trsmKernelL { // Generic Implementation of triangular solve for triangular matrix on left and multiple rhs. // Handles non-packed matrices. - static void trsmKernelL( - Index size, Index otherSize, - const Scalar* _tri, Index triStride, - Scalar* _other, Index otherIncr, Index otherStride); - - // Generic Implementation of triangular solve for triangular matrix on right and multiple lhs. - // Handles non-packed matrices. - static void trsmKernelR( + static void kernel( Index size, Index otherSize, const Scalar* _tri, Index triStride, Scalar* _other, Index otherIncr, Index otherStride); }; template -EIGEN_STRONG_INLINE void trsm_kernels::trsmKernelL( +struct trsmKernelR { + // Generic Implementation of triangular solve for triangular matrix on right and multiple lhs. + // Handles non-packed matrices. + static void kernel( + Index size, Index otherSize, + const Scalar* _tri, Index triStride, + Scalar* _other, Index otherIncr, Index otherStride); +}; + +template +EIGEN_STRONG_INLINE void trsmKernelL::kernel( Index size, Index otherSize, const Scalar* _tri, Index triStride, Scalar* _other, Index otherIncr, Index otherStride) @@ -86,7 +89,7 @@ EIGEN_STRONG_INLINE void trsm_kernels -EIGEN_STRONG_INLINE void trsm_kernels::trsmKernelR( +EIGEN_STRONG_INLINE void trsmKernelR::kernel( Index size, Index otherSize, const Scalar* _tri, Index triStride, Scalar* _other, Index otherIncr, Index otherStride) @@ -168,7 +171,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix::value || std::is_same::value)) ) { @@ -177,7 +180,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix(l2, cols, L2Cap)) { - trsm_kernels::trsmKernelL( + trsmKernelL::kernel( size, cols, _tri, triStride, _other, 1, otherStride); return; } @@ -243,14 +246,14 @@ EIGEN_DONT_INLINE void triangular_solve_matrix::value || std::is_same::value)) ) { i = IsLower ? k2 + k1: k2 - k1 - actualPanelWidth; } #endif - trsm_kernels::trsmKernelL( + trsmKernelL::kernel( actualPanelWidth, actual_cols, _tri + i + (i)*triStride, triStride, _other + i*OtherInnerStride + j2*otherStride, otherIncr, otherStride); @@ -315,7 +318,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix::value || std::is_same::value)) ) { @@ -324,8 +327,8 @@ EIGEN_DONT_INLINE void triangular_solve_matrix(l2, rows, L2Cap)) { - trsm_kernels:: - trsmKernelR(size, rows, _tri, triStride, _other, 1, otherStride); + trsmKernelR:: + kernel(size, rows, _tri, triStride, _other, 1, otherStride); return; } } @@ -420,8 +423,8 @@ EIGEN_DONT_INLINE void triangular_solve_matrix:: - trsmKernelR(actualPanelWidth, actual_mc, + trsmKernelR:: + kernel(actualPanelWidth, actual_mc, _tri + absolute_j2 + absolute_j2*triStride, triStride, _other + i2*OtherInnerStride + absolute_j2*otherStride, otherIncr, otherStride); }