ASAN fixes for AVX512 GEMM/TRSM

b-shi, 2023-03-31 12:58:07 -07:00 (committed by Shi, Brian)
parent 178ef8c97f
commit 15fbddaf9b
2 changed files with 43 additions and 27 deletions


@@ -641,7 +641,7 @@ class gemm_class {
     }
   }

-  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
+  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch, bool no_a_preload = false>
   EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
                                            Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
     const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
@@ -655,8 +655,8 @@ class gemm_class {
     if (max_b_unroll >= 8)
       innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);

-    // Load A after pow-loop.
-    load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
+    // Load A after pow-loop. Skip this at the end to prevent running over the buffer.
+    if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
   }

   /* Inner kernel loop structure.
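Note: the new no_a_preload flag exists because each k-step preloads the next A block for the following iteration; on the final step there is no next block, so the unconditional load_a read past the end of the packed A buffer and tripped ASAN. A minimal standalone sketch of the pattern (illustrative names, not Eigen code):

    #include <cstddef>
    #include <vector>

    // Each iteration "preloads" the operand for the next one; the final
    // iteration is peeled so the preload never reads past the end.
    double sum_with_preload(const std::vector<double>& a) {
      if (a.empty()) return 0.0;
      double acc = 0.0;
      double next = a[0];  // initial preload
      std::size_t i = 0;
      for (; i + 1 < a.size(); ++i) {
        acc += next;
        next = a[i + 1];  // safe: i + 1 < a.size()
      }
      acc += next;  // peeled final step: no preload issued
      return acc;
    }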
@@ -698,7 +698,7 @@ class gemm_class {
    *   bo += b_unroll * kfactor;
    */
-  template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch>
+  template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch, bool no_a_preload = false>
   EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
     int fetchA_idx = 0;
     int fetchB_idx = 0;
@@ -707,18 +707,19 @@ class gemm_class {
     const bool ktail = k_factor == 1;
     static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
+    static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1),
+                  "skipping a preload only allowed when k unroll is 1");

     if (k_factor > 0)
-      innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
+      innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
     if (k_factor > 1)
-      innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
+      innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
     if (k_factor > 2)
-      innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
+      innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
     if (k_factor > 3)
-      innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
+      innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);

     // Advance A/B pointers after uk-loop.
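The static_assert pins down the invariant: the A preload may only be skipped on a k_factor == 1 step, i.e. the single final iteration of the remainder loop below. A compile-time sketch of the same guard (illustrative, not the actual kernel):

    // A kernel step may skip the trailing A preload only when it processes
    // exactly one k iteration (the final remainder step).
    template <int k_factor, bool no_a_preload = false>
    void kernel_step() {
      static_assert(k_factor > 0 && k_factor <= 4, "k_factor must be in [1, 4]");
      static_assert(!no_a_preload || k_factor == 1,
                    "skipping the A preload is only valid for the last k == 1 step");
      // ... body elided ...
    }

    int main() {
      kernel_step<4>();        // ok: fully unrolled step keeps the preload
      kernel_step<1, true>();  // ok: final step skips the preload
      // kernel_step<4, true>();  // rejected at compile time, as intended
    }

Here !no_a_preload || k_factor == 1 is the diff's no_a_preload == false || (no_a_preload == true && k_factor == 1) in shorter form.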
@@ -729,7 +730,7 @@ class gemm_class {
   template <int a_unroll, int b_unroll, int max_b_unroll>
   EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
     const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
-    if (!use_less_a_regs)
+    if (!use_less_a_regs && k > 1)
       a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
     else
       a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
@@ -743,7 +744,13 @@ class gemm_class {
     // Unrolling k-loop by a factor of 4.
     const int max_k_factor = 4;
-    Index loop_count = k / max_k_factor;
+    Index kRem = k % max_k_factor;
+    Index k_ = k - kRem;
+    if (k_ >= max_k_factor) {
+      k_ -= max_k_factor;
+      kRem += max_k_factor;
+    }
+    Index loop_count = k_ / max_k_factor;

     if (loop_count > 0) {
 #ifdef SECOND_FETCH
@@ -771,11 +778,14 @@ class gemm_class {
     }

     // k-loop remainder handling.
-    loop_count = k % max_k_factor;
-    while (loop_count > 0) {
+    loop_count = kRem;
+    while (loop_count > 1) {
       innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
       loop_count--;
     }
+    if (loop_count > 0) {
+      innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
+    }

     // Update C matrix.
     c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
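The reworked trip count guarantees the remainder loop runs at least once whenever k >= 1, so there is always a final iteration that can be issued with no_a_preload = true. A standalone sketch of the arithmetic with worked values (identifiers mirror the diff, but the function itself is illustrative):

    #include <cassert>

    // Split k into unrolled factor-4 blocks plus a remainder of at least
    // one iteration (for k >= 1), so the last k-step can skip the preload.
    void split_k(long k, long& loop_count, long& kRem) {
      const long max_k_factor = 4;
      kRem = k % max_k_factor;
      long k_ = k - kRem;          // largest multiple of 4 <= k
      if (k_ >= max_k_factor) {    // move one full block into the remainder
        k_ -= max_k_factor;
        kRem += max_k_factor;
      }
      loop_count = k_ / max_k_factor;
    }

    int main() {
      long lc, rem;
      split_k(8, lc, rem);  // one unrolled block, four single k-steps
      assert(lc == 1 && rem == 4);
      split_k(7, lc, rem);  // tail only: seven single k-steps
      assert(lc == 0 && rem == 7);
      split_k(3, lc, rem);
      assert(lc == 0 && rem == 3);
    }

Previously kRem could be 0 when k was a multiple of 4, leaving no final single iteration on which to suppress the preload; shifting one full block into the remainder fixes that, at the cost of running up to four k-steps un-unrolled.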


@@ -299,7 +299,7 @@ class transB {
    * 1-D unroll
    *  for(startN = 0; startN < endN; startN++)
    **/
-  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
+  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
       Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
       int64_t remM_ = 0) {
@@ -310,12 +310,18 @@ class transB {
       ymm.packet[packetIndexOffset + startN] =
           ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
     }
-    else ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
+    else {
+      EIGEN_IF_CONSTEXPR(remN_ == 0) {
+        ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
+      }
+      else ymm.packet[packetIndexOffset + startN] =
+          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
+    }

-    aux_loadB<endN, counter - 1, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
+    aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
   }

-  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
+  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
       Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
       int64_t remM_ = 0) {
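For context: the remN_ path swaps a full-width load of a B column for a masked one, so lanes past the valid columns are never touched. AVX-512 masked loads suppress memory accesses (and faults) on masked-off lanes, which is what silences ASAN here. A standalone intrinsics sketch of the pattern, not Eigen's ploadu/remMask machinery (requires AVX-512F):

    #include <immintrin.h>

    // Load only the first n floats (0 <= n < 16) of a possibly short row.
    // Masked-off lanes are zeroed and their memory is never accessed.
    __m512 load_tail(const float* p, int n) {
      const __mmask16 mask = static_cast<__mmask16>((1u << n) - 1u);
      return _mm512_maskz_loadu_ps(mask, p);
    }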
@@ -363,17 +369,17 @@ class transB {
    * 1-D unroll
    *  for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
    **/
-  template <int64_t endN, int64_t counter, bool toTemp, bool remM>
+  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
       Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
       PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
     constexpr int64_t counterReverse = endN - counter;
     constexpr int64_t startN = counterReverse;

-    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false>(&B_temp[startN], LDB_, ymm);
-    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
+    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
   }

-  template <int64_t endN, int64_t counter, bool toTemp, bool remM>
+  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
       Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
       PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
@@ -424,11 +430,11 @@ class transB {
    *  Wrappers for aux_XXXX to hide counter parameter
    ********************************************************/
-  template <int64_t endN, int64_t packetIndexOffset, bool remM>
+  template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                         int64_t remM_ = 0) {
-    aux_loadB<endN, endN, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
+    aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
   }

   template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
@@ -438,13 +444,13 @@ class transB {
     aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
   }

-  template <int64_t unrollN, bool toTemp, bool remM>
+  template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
   static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                              int64_t remM_ = 0) {
-    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM>(&B_arr[0], LDB, ymm, remM_); }
+    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
     else {
-      aux_loadBBlock<unrollN, unrollN, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+      aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
     }
   }
@@ -550,13 +556,13 @@ class transB {
     }
     else EIGEN_IF_CONSTEXPR(unrollN == 2) {
       // load Lx2 B col major, transpose Lx2 row major
-      transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+      transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
       transB::template transposeLxL<0>(ymm);
       transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
     }
     else EIGEN_IF_CONSTEXPR(unrollN == 1) {
       // load Lx1 B col major, transpose Lx1 row major
-      transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+      transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
       transB::template transposeLxL<0>(ymm);
       transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
     }
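The explicit remN_ = 2 / remN_ = 1 arguments mirror what storeBBlock was already given: only unrollN columns are valid, so the loads must be masked to the same width instead of reading a full vector. Assuming remMask<N>(m) selects the first m of N lanes (consistent with its use with remM_ above), the masks involved are tiny; a hypothetical stand-in:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for remMask<N>(m): low m bits set, i.e. the
    // first m lanes of an N-lane vector are active (assumes 0 < m < 64).
    template <int N>
    uint64_t rem_mask(int64_t m) {
      return (uint64_t{1} << m) - 1;
    }

    int main() {
      std::printf("%#llx\n", (unsigned long long)rem_mask<8>(2));  // 0x3, unrollN == 2
      std::printf("%#llx\n", (unsigned long long)rem_mask<8>(1));  // 0x1, unrollN == 1
    }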