ASAN fixes for AVX512 GEMM/TRSM

Author: b-shi, 2023-03-31 12:58:07 -07:00 (committed by Shi, Brian)
parent 178ef8c97f
commit 15fbddaf9b
2 changed files with 43 additions and 27 deletions

File 1 of 2:

@@ -641,7 +641,7 @@ class gemm_class {
     }
   }
 
-  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
+  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch, bool no_a_preload = false>
   EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
                                            Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
     const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
@@ -655,8 +655,8 @@ class gemm_class {
     if (max_b_unroll >= 8)
       innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
 
-    // Load A after pow-loop.
-    load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
+    // Load A after pow-loop. Skip this at the end to prevent running over the buffer
+    if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
   }
 
   /* Inner kernel loop structure.
@@ -698,7 +698,7 @@ class gemm_class {
    *   bo += b_unroll * kfactor;
    */
-  template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch>
+  template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch, bool no_a_preload = false>
   EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
     int fetchA_idx = 0;
     int fetchB_idx = 0;
@@ -707,18 +707,19 @@ class gemm_class {
     const bool ktail = k_factor == 1;
     static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
+    static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1), "skipping a preload only allowed when k unroll is 1");
 
     if (k_factor > 0)
-      innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
+      innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
                                                                                     fetchB_idx);
     if (k_factor > 1)
-      innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
+      innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
                                                                                     fetchB_idx);
     if (k_factor > 2)
-      innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
+      innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
                                                                                     fetchB_idx);
     if (k_factor > 3)
-      innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
+      innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
                                                                                     fetchB_idx);
 
     // Advance A/B pointers after uk-loop.
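
The new static_assert ties the flag to its only legal use: the preload may be skipped only in a k_factor == 1 step, i.e. the final tail iteration. A minimal self-contained sketch of the same defaulted-flag-plus-static_assert pattern (illustrative names, not the kernel itself):

#include <iostream>

// A defaulted compile-time flag guarded by a static_assert, mirroring
// no_a_preload above: existing call sites compile unchanged, and an
// invalid combination fails at compile time instead of at run time.
template <int k_factor, bool no_a_preload = false>
void tail_kernel_demo() {
  static_assert(!no_a_preload || k_factor == 1,
                "skipping the A preload is only allowed when k unroll is 1");
  std::cout << "k_factor=" << k_factor << ", no_a_preload=" << no_a_preload << '\n';
}

int main() {
  tail_kernel_demo<4>();        // normal unrolled step, preload enabled
  tail_kernel_demo<1, true>();  // peeled last step, preload skipped
  // tail_kernel_demo<4, true>();  // would be rejected by the static_assert
}
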
@@ -729,7 +730,7 @@ class gemm_class {
   template <int a_unroll, int b_unroll, int max_b_unroll>
   EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
     const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
-    if (!use_less_a_regs)
+    if (!use_less_a_regs && k > 1)
       a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
     else
       a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
@@ -743,7 +744,13 @@ class gemm_class {
 
     // Unrolling k-loop by a factor of 4.
     const int max_k_factor = 4;
-    Index loop_count = k / max_k_factor;
+    Index kRem = k % max_k_factor;
+    Index k_ = k - kRem;
+    if (k_ >= max_k_factor) {
+      k_ -= max_k_factor;
+      kRem += max_k_factor;
+    }
+    Index loop_count = k_ / max_k_factor;
 
     if (loop_count > 0) {
 #ifdef SECOND_FETCH
@@ -771,11 +778,14 @@ class gemm_class {
     }
 
     // k-loop remainder handling.
-    loop_count = k % max_k_factor;
-    while (loop_count > 0) {
+    loop_count = kRem;
+    while (loop_count > 1) {
       innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
       loop_count--;
     }
+    if (loop_count > 0) {
+      innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
+    }
 
     // Update C matrix.
     c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
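
The k-loop restructuring above is the core of the ASAN fix: each unrolled step preloads the next step's slice of A, so the last step executed must not preload or it reads past the end of the packed A buffer. The new kRem/k_ arithmetic moves one full 4x block into the remainder whenever possible, guaranteeing at least one remainder step, and the remainder loop now stops one short so the final step runs with no_a_preload = true (for k = 8: one 4x block, then three normal steps, then one no-preload step). A minimal scalar sketch of the same peeling pattern (hypothetical function, not the AVX512 kernel):

#include <cstddef>

// Software-pipelined dot product: iteration i preloads a[i + 1], so the
// last iteration is peeled off and run without the preload; otherwise
// the loop would read a[n], one element past the end of the buffer.
double pipelined_dot(const double *a, const double *b, std::size_t n) {
  if (n == 0) return 0.0;
  double acc = 0.0;
  double a_cur = a[0];         // initial preload (like a_loads before the k-loop)
  std::size_t i = 0;
  for (; i + 1 < n; ++i) {     // main loop: a[i + 1] is always in bounds here
    double a_next = a[i + 1];  // preload for the next iteration
    acc += a_cur * b[i];
    a_cur = a_next;
  }
  acc += a_cur * b[i];         // peeled tail: no preload, no over-read
  return acc;
}
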

File 2 of 2:

@@ -299,7 +299,7 @@ class transB {
    * 1-D unroll
    *  for(startN = 0; startN < endN; startN++)
    **/
-  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
+  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
       Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
       int64_t remM_ = 0) {
@@ -310,12 +310,18 @@ class transB {
       ymm.packet[packetIndexOffset + startN] =
           ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
     }
-    else ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
+    else {
+      EIGEN_IF_CONSTEXPR(remN_ == 0) {
+        ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
+      }
+      else ymm.packet[packetIndexOffset + startN] =
+          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
+    }
 
-    aux_loadB<endN, counter - 1, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
+    aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
   }
 
-  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
+  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
       Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
       int64_t remM_ = 0) {
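
The EIGEN_IF_CONSTEXPR(remN_ == 0) branch above is the second half of the fix: when fewer than a full vector of elements remains, the load goes through remMask(remN_) instead of an unmasked ploadu that would read past the end of B. A standalone AVX-512VL sketch of the same idea (assumed intrinsics usage, compiled with -mavx512f -mavx512vl; Eigen's real path goes through remMask<> and ploadu):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Remainder-masked load: touch only the first `rem` floats (0 < rem <= 8)
// so the sanitizer never sees a read beyond the end of the buffer.
static inline __m256 load_rem(const float *p, int64_t rem) {
  const __mmask8 mask = static_cast<__mmask8>((1u << rem) - 1u);  // low `rem` lanes set
  return _mm256_maskz_loadu_ps(mask, p);  // masked-off lanes are not read, just zeroed
}

int main() {
  float tail[3] = {1.f, 2.f, 3.f};  // only 3 valid floats; a full 8-wide load would overrun
  float out[8];
  _mm256_storeu_ps(out, load_rem(tail, 3));
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // prints: 1 2 3 0
}
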
@@ -363,17 +369,17 @@ class transB {
    * 1-D unroll
    *  for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
    **/
-  template <int64_t endN, int64_t counter, bool toTemp, bool remM>
+  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
       Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
       PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
     constexpr int64_t counterReverse = endN - counter;
     constexpr int64_t startN = counterReverse;
 
-    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false>(&B_temp[startN], LDB_, ymm);
-    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
+    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
   }
 
-  template <int64_t endN, int64_t counter, bool toTemp, bool remM>
+  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
       Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
       PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
@@ -424,11 +430,11 @@ class transB {
    * Wrappers for aux_XXXX to hide counter parameter
    ********************************************************/
-  template <int64_t endN, int64_t packetIndexOffset, bool remM>
+  template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
   static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                         int64_t remM_ = 0) {
-    aux_loadB<endN, endN, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
+    aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
   }
 
   template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
@@ -438,13 +444,13 @@ class transB {
     aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
   }
 
-  template <int64_t unrollN, bool toTemp, bool remM>
+  template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
   static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                              int64_t remM_ = 0) {
-    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM>(&B_arr[0], LDB, ymm, remM_); }
+    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
     else {
-      aux_loadBBlock<unrollN, unrollN, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+      aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
     }
   }
@@ -550,13 +556,13 @@ class transB {
     }
     else EIGEN_IF_CONSTEXPR(unrollN == 2) {
       // load Lx2 B col major, transpose Lx2 row major
-      transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+      transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
       transB::template transposeLxL<0>(ymm);
       transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
     }
     else EIGEN_IF_CONSTEXPR(unrollN == 1) {
       // load Lx1 B col major, transpose Lx1 row major
-      transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
+      transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
       transB::template transposeLxL<0>(ymm);
       transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
     }
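
The remaining changes only thread the new remN_ parameter through the recursive unroll helpers down to the innermost load. A toy model of that recursion (illustrative names, not Eigen's): counter counts down at compile time while remN_ rides along unchanged, exactly as in aux_loadB/aux_loadBBlock.

#include <cstdint>
#include <iostream>
#include <type_traits>

// Compile-time counted recursion in the style of the aux_ helpers: the
// (counter > 0) overload does one step and recurses, the (counter <= 0)
// overload terminates; remN_ is threaded through every level.
template <int64_t endN, int64_t counter, int64_t remN_>
static std::enable_if_t<(counter > 0)> unroll_demo() {
  constexpr int64_t startN = endN - counter;  // current unrolled index
  std::cout << "step " << startN << ", remN_=" << remN_ << '\n';
  unroll_demo<endN, counter - 1, remN_>();
}
template <int64_t endN, int64_t counter, int64_t remN_>
static std::enable_if_t<(counter <= 0)> unroll_demo() {}  // recursion base

int main() {
  unroll_demo<2, 2, 2>();  // like loadBBlock<2, toTemp, remM, 2> in the diff
}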