mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-23 06:43:13 +08:00
ASAN fixes for AVX512 GEMM/TRSM
This commit is contained in:
parent
178ef8c97f
commit
15fbddaf9b
@ -641,7 +641,7 @@ class gemm_class {
|
||||
}
|
||||
}
|
||||
|
||||
template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
|
||||
template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch, bool no_a_preload = false>
|
||||
EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
|
||||
Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
|
||||
const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
|
||||
@ -655,8 +655,8 @@ class gemm_class {
|
||||
if (max_b_unroll >= 8)
|
||||
innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
|
||||
// Load A after pow-loop.
|
||||
load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
|
||||
// Load A after pow-loop. Skip this at the end to prevent running over the buffer
|
||||
if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
|
||||
}
|
||||
|
||||
/* Inner kernel loop structure.
|
||||
@ -698,7 +698,7 @@ class gemm_class {
|
||||
* bo += b_unroll * kfactor;
|
||||
*/
|
||||
|
||||
template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch>
|
||||
template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch, bool no_a_preload = false>
|
||||
EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
|
||||
int fetchA_idx = 0;
|
||||
int fetchB_idx = 0;
|
||||
@ -707,18 +707,19 @@ class gemm_class {
|
||||
const bool ktail = k_factor == 1;
|
||||
|
||||
static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
|
||||
static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1), "skipping a preload only allowed when k unroll is 1");
|
||||
|
||||
if (k_factor > 0)
|
||||
innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
|
||||
innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
|
||||
fetchB_idx);
|
||||
if (k_factor > 1)
|
||||
innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
|
||||
innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
|
||||
fetchB_idx);
|
||||
if (k_factor > 2)
|
||||
innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
|
||||
innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
|
||||
fetchB_idx);
|
||||
if (k_factor > 3)
|
||||
innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
|
||||
innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx,
|
||||
fetchB_idx);
|
||||
|
||||
// Advance A/B pointers after uk-loop.
|
||||
@ -729,7 +730,7 @@ class gemm_class {
|
||||
template <int a_unroll, int b_unroll, int max_b_unroll>
|
||||
EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
|
||||
const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
|
||||
if (!use_less_a_regs)
|
||||
if (!use_less_a_regs && k > 1)
|
||||
a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
|
||||
else
|
||||
a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
|
||||
@ -743,7 +744,13 @@ class gemm_class {
|
||||
|
||||
// Unrolling k-loop by a factor of 4.
|
||||
const int max_k_factor = 4;
|
||||
Index loop_count = k / max_k_factor;
|
||||
Index kRem = k % max_k_factor;
|
||||
Index k_ = k - kRem;
|
||||
if (k_ >= max_k_factor) {
|
||||
k_ -= max_k_factor;
|
||||
kRem += max_k_factor;
|
||||
}
|
||||
Index loop_count = k_ / max_k_factor;
|
||||
|
||||
if (loop_count > 0) {
|
||||
#ifdef SECOND_FETCH
|
||||
@ -771,11 +778,14 @@ class gemm_class {
|
||||
}
|
||||
|
||||
// k-loop remainder handling.
|
||||
loop_count = k % max_k_factor;
|
||||
while (loop_count > 0) {
|
||||
loop_count = kRem;
|
||||
while (loop_count > 1) {
|
||||
innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
|
||||
loop_count--;
|
||||
}
|
||||
if (loop_count > 0) {
|
||||
innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
|
||||
}
|
||||
|
||||
// Update C matrix.
|
||||
c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
|
||||
|
@ -299,7 +299,7 @@ class transB {
|
||||
* 1-D unroll
|
||||
* for(startN = 0; startN < endN; startN++)
|
||||
**/
|
||||
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
|
||||
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
|
||||
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
|
||||
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
||||
int64_t remM_ = 0) {
|
||||
@ -310,12 +310,18 @@ class transB {
|
||||
ymm.packet[packetIndexOffset + startN] =
|
||||
ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
||||
}
|
||||
else ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
|
||||
else {
|
||||
EIGEN_IF_CONSTEXPR(remN_ == 0) {
|
||||
ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
|
||||
}
|
||||
else ymm.packet[packetIndexOffset + startN] =
|
||||
ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
|
||||
}
|
||||
|
||||
aux_loadB<endN, counter - 1, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
|
||||
aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
|
||||
}
|
||||
|
||||
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM>
|
||||
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
|
||||
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
|
||||
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
||||
int64_t remM_ = 0) {
|
||||
@ -363,17 +369,17 @@ class transB {
|
||||
* 1-D unroll
|
||||
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
|
||||
**/
|
||||
template <int64_t endN, int64_t counter, bool toTemp, bool remM>
|
||||
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
|
||||
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
|
||||
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
||||
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
|
||||
constexpr int64_t counterReverse = endN - counter;
|
||||
constexpr int64_t startN = counterReverse;
|
||||
transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false>(&B_temp[startN], LDB_, ymm);
|
||||
aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
|
||||
aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
}
|
||||
|
||||
template <int64_t endN, int64_t counter, bool toTemp, bool remM>
|
||||
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
|
||||
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
|
||||
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
||||
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
|
||||
@ -424,11 +430,11 @@ class transB {
|
||||
* Wrappers for aux_XXXX to hide counter parameter
|
||||
********************************************************/
|
||||
|
||||
template <int64_t endN, int64_t packetIndexOffset, bool remM>
|
||||
template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
|
||||
static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
|
||||
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
||||
int64_t remM_ = 0) {
|
||||
aux_loadB<endN, endN, packetIndexOffset, remM>(B_arr, LDB, ymm, remM_);
|
||||
aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
|
||||
}
|
||||
|
||||
template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
|
||||
@ -438,13 +444,13 @@ class transB {
|
||||
aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
|
||||
}
|
||||
|
||||
template <int64_t unrollN, bool toTemp, bool remM>
|
||||
template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
|
||||
static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
||||
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
||||
int64_t remM_ = 0) {
|
||||
EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM>(&B_arr[0], LDB, ymm, remM_); }
|
||||
EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
|
||||
else {
|
||||
aux_loadBBlock<unrollN, unrollN, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
}
|
||||
}
|
||||
|
||||
@ -550,13 +556,13 @@ class transB {
|
||||
}
|
||||
else EIGEN_IF_CONSTEXPR(unrollN == 2) {
|
||||
// load Lx2 B col major, transpose Lx2 row major
|
||||
transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
transB::template transposeLxL<0>(ymm);
|
||||
transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
}
|
||||
else EIGEN_IF_CONSTEXPR(unrollN == 1) {
|
||||
// load Lx1 B col major, transpose Lx1 row major
|
||||
transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
transB::template transposeLxL<0>(ymm);
|
||||
transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user