diff --git a/Eigen/src/Core/arch/AVX512/GemmKernel.h b/Eigen/src/Core/arch/AVX512/GemmKernel.h
index cb7cfdffc..616a058ab 100644
--- a/Eigen/src/Core/arch/AVX512/GemmKernel.h
+++ b/Eigen/src/Core/arch/AVX512/GemmKernel.h
@@ -641,7 +641,7 @@ class gemm_class {
     }
   }
 
-  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
+  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch, bool no_a_preload = false>
   EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
                                            Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
     const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
@@ -655,8 +655,8 @@ class gemm_class {
     if (max_b_unroll >= 8) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
 
-    // Load A after pow-loop.
-    load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
+    // Load A after pow-loop. Skip this at the end to prevent running over the buffer
+    if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
   }
 
   /* Inner kernel loop structure.
    *   bo += b_unroll * kfactor;
    */
-  template <int max_b_unroll, int a_unroll, int b_unroll, int k_factor, bool fetch_x, bool c_fetch>
+  template <int max_b_unroll, int a_unroll, int b_unroll, int k_factor, bool fetch_x, bool c_fetch, bool no_a_preload = false>
   EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
     int fetchA_idx = 0;
     int fetchB_idx = 0;
@@ -707,18 +707,19 @@ class gemm_class {
     const bool ktail = k_factor == 1;
     static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
+    static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1), "skipping a preload only allowed when k unroll is 1");
 
     if (k_factor > 0)
-      innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
+      innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
                                                                                     fetchB_idx);
     if (k_factor > 1)
-      innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
+      innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
                                                                                     fetchB_idx);
     if (k_factor > 2)
-      innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
+      innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
                                                                                     fetchB_idx);
     if (k_factor > 3)
-      innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
+      innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
                                                                                     fetchB_idx);
 
     // Advance A/B pointers after uk-loop.
@@ -729,7 +730,7 @@ class gemm_class {
   template
   EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
     const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
-    if (!use_less_a_regs)
+    if (!use_less_a_regs && k > 1)
       a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
     else
       a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
@@ -743,7 +744,13 @@ class gemm_class {
     // Unrolling k-loop by a factor of 4.
     const int max_k_factor = 4;
-    Index loop_count = k / max_k_factor;
+    Index kRem = k % max_k_factor;
+    Index k_ = k - kRem;
+    if (k_ >= max_k_factor) {
+      k_ -= max_k_factor;
+      kRem += max_k_factor;
+    }
+    Index loop_count = k_ / max_k_factor;
 
     if (loop_count > 0) {
 #ifdef SECOND_FETCH
@@ -771,11 +778,14 @@ class gemm_class {
     }
 
     // k-loop remainder handling.
-    loop_count = k % max_k_factor;
-    while (loop_count > 0) {
+    loop_count = kRem;
+    while (loop_count > 1) {
       innerkernel(aa, ao, bo, co2);
       loop_count--;
     }
+    if (loop_count > 0) {
+      innerkernel(aa, ao, bo, co2);
+    }
 
     // Update C matrix.
     c_update(co1, co2);
diff --git a/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc
index e137d6a9a..4c6116c69 100644
--- a/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc
+++ b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc
@@ -299,7 +299,7 @@ class transB {
    * 1-D unroll
    *      for(startN = 0; startN < endN; startN++)
    **/
-  template
+  template
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
       Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t remM_ = 0) {
@@ -310,12 +310,18 @@ class transB {
       ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB], remMask(remM_));
     }
-    else ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB]);
+    else {
+      EIGEN_IF_CONSTEXPR(remN_ == 0) {
+        ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB]);
+      }
+      else ymm.packet[packetIndexOffset + startN] =
+          ploadu((const Scalar *)&B_arr[startN * LDB], remMask(remN_));
+    }
 
-    aux_loadB(B_arr, LDB, ymm, remM_);
+    aux_loadB(B_arr, LDB, ymm, remM_);
   }
 
-  template
+  template
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
       Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t remM_ = 0) {
@@ -363,17 +369,17 @@ class transB {
    * 1-D unroll
    *      for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
    **/
-  template
+  template
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
       Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, PacketBlock &ymm, int64_t remM_ = 0) {
     constexpr int64_t counterReverse = endN - counter;
     constexpr int64_t startN = counterReverse;
-    transB::template loadB(&B_temp[startN], LDB_, ymm);
-    aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+    transB::template loadB(&B_temp[startN], LDB_, ymm);
+    aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_);
   }
 
-  template
+  template
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
       Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, PacketBlock &ymm, int64_t remM_ = 0) {
@@ -424,11 +430,11 @@ class transB {
   * Wrappers for aux_XXXX to hide counter parameter
   ********************************************************/
-  template
+  template
   static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t remM_ = 0) {
-    aux_loadB(B_arr, LDB, ymm, remM_);
+    aux_loadB(B_arr, LDB, ymm, remM_);
   }
 
   template
@@ -438,13 +444,13 @@ class transB {
     aux_storeB(B_arr, LDB, ymm, rem_);
   }
 
-  template
+  template
   static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock &ymm, int64_t remM_ = 0) {
-    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB(&B_arr[0], LDB, ymm, remM_); }
+    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB(&B_arr[0], LDB, ymm, remM_); }
     else {
-      aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+      aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_);
     }
   }
@@ -550,13 +556,13 @@ class transB {
     }
     else EIGEN_IF_CONSTEXPR(unrollN == 2) {
       // load Lx2 B col major, transpose Lx2 row major
-      transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+      transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
       transB::template transposeLxL<0>(ymm);
       transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
     }
     else EIGEN_IF_CONSTEXPR(unrollN == 1) {
       // load Lx1 B col major, transpose Lx1 row major
-      transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+      transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
       transB::template transposeLxL<0>(ymm);
       transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
     }
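
Note on the GemmKernel.h hunks (a minimal standalone sketch, not part of the patch; the variable names below are illustrative): the k-loop is still unrolled by max_k_factor = 4, but the last full group of iterations is folded into the remainder loop, so the final micro-step can run with no_a_preload == true and never read A past the end of the packed buffer. The splitting arithmetic can be checked in isolation:

#include <cassert>

int main() {
  const long max_k_factor = 4;
  for (long k = 1; k <= 64; ++k) {
    long kRem = k % max_k_factor;
    long k_ = k - kRem;
    if (k_ >= max_k_factor) {  // peel one full group of 4 into the remainder
      k_ -= max_k_factor;
      kRem += max_k_factor;
    }
    // The unrolled-by-4 loop runs k_ / max_k_factor times, the remainder loop kRem times.
    assert(k_ % max_k_factor == 0);
    assert(k_ + kRem == k);  // every k iteration is still executed exactly once
    assert(kRem >= 1);       // there is always a final iteration that can skip the A preload
  }
}

Under this split, only the very last remainder iteration is issued without the A preload, which mirrors the new while (loop_count > 1) / if (loop_count > 0) structure in kloop.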