// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
#define EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H

#if EIGEN_COMP_MSVC
#include <intrin.h>
#else
#include <x86intrin.h>
#endif
#include <immintrin.h>
#include <type_traits>

#include "../../InternalHeaderCheck.h"

#if !defined(EIGEN_USE_AVX512_GEMM_KERNELS)
#define EIGEN_USE_AVX512_GEMM_KERNELS 1
#endif

#define SECOND_FETCH (32)
#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
// Use fewer registers to load A elements to work around compiler spills. Loses a
// bit of performance (less than ~2%).
#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
#endif

namespace Eigen {
namespace internal {

template <typename Scalar, bool is_unit_inc>
class gemm_class {
  using vec = typename packet_traits<Scalar>::type;
  using vec_ymm = typename unpacket_traits<vec>::half;
  using vec_xmm = typename unpacket_traits<vec_ymm>::half;
  using umask_t = typename unpacket_traits<vec>::mask_t;

  static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
  static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);

#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
  static constexpr bool use_less_a_regs = !is_unit_inc;
#else
  static constexpr bool use_less_a_regs = true;
#endif
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
  static constexpr bool use_less_b_regs = !is_unit_inc;
#else
  static constexpr bool use_less_b_regs = true;
#endif

  // Register indices used for A, B, and the C accumulators (see zmm[] below).
  static constexpr int a_regs[] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5};
  static constexpr int b_regs[] = {6, use_less_b_regs ? 6 : 7};
  static constexpr int c_regs[] = {
      8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
  };

  static constexpr int alpha_load_reg = 0;
  static constexpr int c_load_regs[] = {1, 2, 6};

  static constexpr int a_shift = 128;
  static constexpr int b_shift = 128;

  static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
  static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
  static constexpr int b_prefetch_size = nelems_in_cache_line * 8;

  // The 32 AVX512 zmm registers, modeled explicitly as the kernel's register file.
  vec zmm[32];
  umask_t mask;

  // gemm arguments.
  Index m;
  const Index n, k, ldc;
  const Index inc;
  const Scalar *alpha;

  const Scalar *a, *b;
  Scalar *c;

  const bool is_alpha1;
  const bool is_beta0;

  const Index a_stride, b_stride;
  const Index a_off, b_off;

  static EIGEN_ALWAYS_INLINE constexpr int div_up(int a, int b) { return a == 0 ?
0 : (a - 1) / b + 1; } EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) { _mm_prefetch((char *)(a_prefetch_size + a_addr - a_shift), _MM_HINT_T0); } EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) { _mm_prefetch((char *)(b_prefetch_size + b_addr - b_shift), _MM_HINT_T0); } EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) { _mm_prefetch((char *)(x_addr - a_shift), _MM_HINT_T2); } EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) { #if defined(__PRFCHW__) && __PRFCHW__ == 1 _m_prefetchw((void *)c_addr); #else _mm_prefetch((char *)c_addr, _MM_HINT_T0); #endif } template EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) { switch (nelems * sizeof(*a_addr) * 8) { default: case 512 * 3: a_reg = ploadu(a_addr); break; case 512 * 2: a_reg = ploadu(a_addr); break; case 512 * 1: a_reg = ploadu(a_addr); break; case 256 * 1: a_reg = preinterpret(_mm512_broadcast_f64x4(ploadu(reinterpret_cast(a_addr)))); break; case 128 * 1: a_reg = preinterpret(_mm512_broadcast_f32x4(ploadu(reinterpret_cast(a_addr)))); break; case 64 * 1: a_reg = preinterpret(pload1(reinterpret_cast(a_addr))); break; case 32 * 1: a_reg = pload1(a_addr); break; } } EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) { b_reg = pload1(b_addr); } template EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) { if (is_unit_inc) { switch (nelems * sizeof(*mem) * 8) { default: case 512 * 3: pstoreu(mem, src); break; case 512 * 2: pstoreu(mem, src); break; case 512 * 1: pstoreu(mem, src); break; case 256 * 1: pstoreu(mem, preinterpret(src)); break; case 128 * 1: pstoreu(mem, preinterpret(src)); break; case 64 * 1: pstorel(mem, preinterpret(src)); break; case 32 * 1: pstores(mem, preinterpret(src)); break; } } else { switch (nelems * sizeof(*mem) * 8) { default: case 512 * 3: pscatter(mem, src, inc); break; case 512 * 2: pscatter(mem, src, inc); break; case 512 * 1: pscatter(mem, src, inc); break; case 256 * 1: pscatter(mem, src, inc, mask); break; case 128 * 1: pscatter(mem, src, inc, mask); break; case 64 * 1: pscatter(mem, src, inc, mask); break; case 32 * 1: pscatter(mem, src, inc, mask); break; } } } template EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec ®) { if (is_unit_inc) { switch (nelems * sizeof(*mem) * 8) { default: case 512 * 3: dst = padd(src, ploadu(mem)); break; case 512 * 2: dst = padd(src, ploadu(mem)); break; case 512 * 1: dst = padd(src, ploadu(mem)); break; case 256 * 1: dst = preinterpret(padd(preinterpret(src), ploadu(mem))); break; case 128 * 1: dst = preinterpret(padd(preinterpret(src), ploadu(mem))); break; case 64 * 1: dst = preinterpret(padd(preinterpret(src), ploadl(mem))); break; case 32 * 1: dst = preinterpret(padds(preinterpret(src), ploads(mem))); break; } } else { // Zero out scratch register reg = pzero(reg); switch (nelems * sizeof(*mem) * 8) { default: case 512 * 3: reg = pgather(mem, inc); dst = padd(src, reg); break; case 512 * 2: reg = pgather(mem, inc); dst = padd(src, reg); break; case 512 * 1: reg = pgather(mem, inc); dst = padd(src, reg); break; case 256 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); break; case 128 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); break; case 64 * 1: if (is_f32) { reg = pgather(reg, mem, inc, mask); dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); } else { dst = preinterpret(padd(preinterpret(src), ploadl(mem))); } break; case 32 * 1: dst = 
preinterpret(padds(preinterpret(src), ploads(mem))); break; } } } EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) { dst = pmadd(src1, src2, dst); #if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0) // Workaround register spills for gcc and clang __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2)); #endif } template EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec ®) { if (is_unit_inc) { switch (nelems * sizeof(*mem) * 8) { default: case 512 * 3: dst = pmadd(scale, src, ploadu(mem)); break; case 512 * 2: dst = pmadd(scale, src, ploadu(mem)); break; case 512 * 1: dst = pmadd(scale, src, ploadu(mem)); break; case 256 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); break; case 128 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); break; case 64 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadl(mem))); break; case 32 * 1: dst = preinterpret(pmadds(preinterpret(scale), preinterpret(src), ploads(mem))); break; } } else { // Zero out scratch register reg = pzero(reg); switch (nelems * sizeof(*mem) * 8) { default: case 512 * 3: reg = pgather(mem, inc); dst = pmadd(scale, src, reg); break; case 512 * 2: reg = pgather(mem, inc); dst = pmadd(scale, src, reg); break; case 512 * 1: reg = pgather(mem, inc); dst = pmadd(scale, src, reg); break; case 256 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret( pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); break; case 128 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret( pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); break; case 64 * 1: if (is_f32) { reg = pgather(reg, mem, inc, mask); dst = preinterpret( pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); } else { dst = preinterpret( pmadd(preinterpret(scale), preinterpret(src), ploadl(mem))); } break; case 32 * 1: dst = preinterpret(pmadds(preinterpret(scale), preinterpret(src), ploads(mem))); break; } } } template EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) { EIGEN_UNUSED_VARIABLE(ao); } template EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) { if (j < endX) { if (i < endY) { auto &a_reg = zmm[a_regs[i + (j % 2) * 3]]; const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift; a_load(a_reg, a_addr); a_loads(ao); } else { a_loads(ao); } } } template EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1, const Scalar *co2) { EIGEN_UNUSED_VARIABLE(co1); EIGEN_UNUSED_VARIABLE(co2); } /* C prefetch loop structure. * for (int un = 0; un < 8; un++) { * if (b_unroll >= un + 1) { * if (un == 4) co2 = co1 + 4 * ldc; * * for (int i = 0; i < um_vecs; i++) { * Scalar *co = (un + 1 <= 4) ? co1 : co2; * auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; * prefetch_c(co + co_off); * } * } * } */ template EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) { if (un < max_b_unroll) { if (b_unroll >= un + 1) { if (un == 4 && i == 0) co2 = co1 + 4 * ldc; if (i < um_vecs) { Scalar *co = (un + 1 <= 4) ? 
co1 : co2; auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; prefetch_c(co + co_off); prefetch_cs(co1, co2); } else { prefetch_cs(co1, co2); } } else { prefetch_cs(co1, co2); } } } // load_c template EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) { EIGEN_UNUSED_VARIABLE(cox); EIGEN_UNUSED_VARIABLE(alpha_reg); } template EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) { if (i < um_vecs) { auto &c_reg = zmm[c_regs[i + idx * 3]]; auto &c_load_reg = zmm[c_load_regs[i % 3]]; auto c_mem = cox; if (is_unit_inc) c_mem += i * nelems_in_cache_line; else c_mem += i * nelems_in_cache_line * inc; if (!is_beta0 && is_alpha1) vaddm(c_reg, c_mem, c_reg, c_load_reg); else if (!is_beta0 && !is_alpha1) vfmaddm(c_reg, c_mem, c_reg, alpha_reg, c_load_reg); else if (is_beta0 && !is_alpha1) c_reg = pmul(alpha_reg, c_reg); scale_load_c(cox, alpha_reg); } } // store_c template EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) { EIGEN_UNUSED_VARIABLE(cox); } template EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) { if (i < um_vecs) { auto &c_reg = zmm[c_regs[i + idx * 3]]; auto c_mem = cox; if (is_unit_inc) c_mem += i * nelems_in_cache_line; else c_mem += i * nelems_in_cache_line * inc; c_store(c_mem, c_reg); c_reg = pzero(c_reg); write_c(cox); } } /* C update loop structure. * co2 = co1 + ldc; * * auto &alpha_reg = zmm[alpha_load_reg]; * if (!is_alpha1) alpha_reg = pload1(alpha); * * int idx = 0; * for (pow = 1; pow <= 8; pow <<= 1) { * * if (b_unroll >= pow) { * for (count = 1; count < (pow + 1) / 2 + 1; count++) { * if (pow >= 4) co2 += ldc; * * const Scalar *cox = (idx == 0) ? co1 : co2; * * const int um_vecs = div_up(a_unroll, nelems_in_cache_line); * scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); * write_c<0, um_vecs, idx, a_unroll>(cox); * * idx++; * } * } * } * * if (b_unroll == 1) * co1 += ldc; * else * co1 = co2 + ldc; */ template EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) { if (pow >= 4) cox += ldc; const int um_vecs = div_up(a_unroll, nelems_in_cache_line); auto &alpha_reg = zmm[alpha_load_reg]; scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); write_c<0, um_vecs, idx, a_unroll>(cox); } template EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) { constexpr int idx = pow / 2; Scalar *&cox = idx == 0 ? 
co1 : co2; constexpr int max_count = (pow + 1) / 2; static_assert(max_count <= 4, "Unsupported max_count."); if (1 <= max_count) c_update_1count(cox); if (2 <= max_count) c_update_1count(cox); if (3 <= max_count) c_update_1count(cox); if (4 <= max_count) c_update_1count(cox); } template EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) { auto &alpha_reg = zmm[alpha_load_reg]; co2 = co1 + ldc; if (!is_alpha1) alpha_reg = pload1(alpha); if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast((1ull << a_unroll) - 1); static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll"); if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2); if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2); if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2); if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2); if (b_unroll == 1) co1 += ldc; else co1 = co2 + ldc; } // compute template EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg) { EIGEN_UNUSED_VARIABLE(ao); EIGEN_UNUSED_VARIABLE(bo); EIGEN_UNUSED_VARIABLE(fetchA_idx); EIGEN_UNUSED_VARIABLE(fetchB_idx); EIGEN_UNUSED_VARIABLE(b_reg); } template EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg) { if (um < um_vecs) { auto &c_reg = zmm[c_regs[um + idx * 3]]; auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]]; vfmadd(c_reg, a_reg, b_reg); if (!fetch_x && um == 0 && (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) || (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) { prefetch_a(ao + nelems_in_cache_line * fetchA_idx); fetchA_idx++; } if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) { prefetch_b(bo + nelems_in_cache_line * fetchB_idx); fetchB_idx++; } compute(ao, bo, fetchA_idx, fetchB_idx, b_reg); } } // load_a template EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) { EIGEN_UNUSED_VARIABLE(ao); } template EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) { if (um < um_vecs) { auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]]; const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift; a_load(a_reg, a_addr); load_a(ao); } } template EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) { EIGEN_UNUSED_VARIABLE(aa); EIGEN_UNUSED_VARIABLE(ao); EIGEN_UNUSED_VARIABLE(bo); EIGEN_UNUSED_VARIABLE(co2); EIGEN_UNUSED_VARIABLE(fetchA_idx); EIGEN_UNUSED_VARIABLE(fetchB_idx); } template EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) { const int idx = (pow / 2) + count; if (count < (pow + 1) / 2) { auto &b_reg = zmm[b_regs[idx % 2]]; if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); if (fetch_x && uk == 3 && idx == 4) aa += 8; if (b_unroll >= pow) { compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift; b_load(b_reg, b_addr); } // Go to the next count. 
innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); } else { // Maybe prefetch C data after count-loop. if (pow == 2 && c_fetch) { if (uk % 3 == 0 && uk > 0) { co2 += ldc; } else { prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); } } } } template EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) { const int um_vecs = div_up(a_unroll, nelems_in_cache_line); if (max_b_unroll >= 1) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); if (max_b_unroll >= 2) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); if (max_b_unroll >= 4) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); if (max_b_unroll >= 8) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); // Load A after pow-loop. Skip this at the end to prevent running over the buffer if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao); } /* Inner kernel loop structure. * for (int uk = 0; uk < kfactor; uk++) { * int idx = 0; * * for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) { * for (int count = 0; count < (pow + 1) / 2; count++) { * auto &b_reg = zmm[b_regs[idx % 2]]; * * if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); * if (fetch_x && uk == 3 && idx == 4) aa += 8; * * if (b_unroll >= pow) { * compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); * * const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ; * b_load(b_reg, b_addr); * } * idx++; * } * * Maybe prefetch C data. * if (pow == 2 && c_fetch) { * if (uk % 3 == 0 && uk > 0) { * co2 += ldc; * } else { * prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); * } * } * } * * Load A. * load_a<0, um_vecs, uk, ktail, a_unroll>(ao); * } * * Advance A/B pointers after uk-loop. * ao += a_unroll * kfactor; * bo += b_unroll * kfactor; */ template EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) { int fetchA_idx = 0; int fetchB_idx = 0; const bool fetch_x = k_factor == max_k_factor; const bool ktail = k_factor == 1; static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4"); static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1), "skipping a preload only allowed when k unroll is 1"); if (k_factor > 0) innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); if (k_factor > 1) innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); if (k_factor > 2) innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); if (k_factor > 3) innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); // Advance A/B pointers after uk-loop. 
ao += a_unroll * k_factor; bo += b_unroll * k_factor; } template EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) { const int um_vecs = div_up(a_unroll, nelems_in_cache_line); if (!use_less_a_regs && k > 1) a_loads<0, 2, 0, um_vecs, a_unroll>(ao); else a_loads<0, 1, 0, um_vecs, a_unroll>(ao); b_load(zmm[b_regs[0]], bo - b_shift + 0); if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1); #ifndef SECOND_FETCH prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2); #endif // SECOND_FETCH // Unrolling k-loop by a factor of 4. const int max_k_factor = 4; Index kRem = k % max_k_factor; Index k_ = k - kRem; if (k_ >= max_k_factor) { k_ -= max_k_factor; kRem += max_k_factor; } Index loop_count = k_ / max_k_factor; if (loop_count > 0) { #ifdef SECOND_FETCH loop_count -= SECOND_FETCH; #endif while (loop_count > 0) { innerkernel(aa, ao, bo, co2); loop_count--; } #ifdef SECOND_FETCH co2 = co1 + nelems_in_cache_line - 1; loop_count += b_unroll; while (loop_count > 0) { innerkernel(aa, ao, bo, co2); loop_count--; } loop_count += SECOND_FETCH - b_unroll; while (loop_count > 0) { innerkernel(aa, ao, bo, co2); loop_count--; } #endif } // k-loop remainder handling. loop_count = kRem; while (loop_count > 1) { innerkernel(aa, ao, bo, co2); loop_count--; } if (loop_count > 0) { innerkernel(aa, ao, bo, co2); } // Update C matrix. c_update(co1, co2); } template EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) { // Set A matrix pointer. ao = a + a_off * a_unroll; // Set B matrix pointer if needed. bo += b_unroll * b_off; kloop(aa, ao, bo, co1, co2); // Advance B matrix pointer if needed. bo += b_unroll * (b_stride - k - b_off); // Advance prefetch A pointer. aa += 16; } template EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) { // Set prefetch A pointers. const Scalar *aa = a + a_unroll * a_stride; // Set C matrix pointers. co1 = c; if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc; if (is_unit_inc) c += a_unroll; else c += a_unroll * inc; // Set B matrix pointer. bo = b; // Main n-loop. for (Index i = n / max_b_unroll; i > 0; i--) nloop(aa, ao, bo, co1, co2); // n-remainders. if (n & 4 && max_b_unroll > 4) nloop(aa, ao, bo, co1, co2); #if 0 if (n & 2 && max_b_unroll > 2) nloop(aa, ao, bo, co1, co2); if (n & 1 && max_b_unroll > 1) nloop(aa, ao, bo, co1, co2); #else // Copy kernels don't support tails of n = 2 for single/double precision. // Loop over ones. int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0); while (n_rem > 0) { nloop(aa, ao, bo, co1, co2); n_rem--; } #endif // Advance A matrix pointer. a = ao + a_unroll * (a_stride - k - a_off); } public: // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll. template EIGEN_ALWAYS_INLINE void compute_kern() { a -= -a_shift; b -= -b_shift; const Scalar *ao = nullptr; const Scalar *bo = nullptr; Scalar *co1 = nullptr; Scalar *co2 = nullptr; // Main m-loop. for (; m >= max_a_unroll; m -= max_a_unroll) mloop(ao, bo, co1, co2); // m-remainders. 
if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); // Copy kernels don't support tails of m = 2 for single precision. // Loop over ones. if (is_f32) { int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0); while (m_rem > 0) { mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); m_rem--; } } } gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_, const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_, Index a_off_, Index b_off_) : m(m_), n(n_), k(k_), ldc(ldc_), inc(inc_), alpha(alpha_), a(a_), b(b_), c(c_), is_alpha1(is_alpha1_), is_beta0(is_beta0_), a_stride(a_stride_), b_stride(b_stride_), a_off(a_off_), b_off(b_off_) { // Zero out all accumulation registers. zmm[8] = pzero(zmm[8]); zmm[9] = pzero(zmm[9]); zmm[10] = pzero(zmm[10]); zmm[11] = pzero(zmm[11]); zmm[12] = pzero(zmm[12]); zmm[13] = pzero(zmm[13]); zmm[14] = pzero(zmm[14]); zmm[15] = pzero(zmm[15]); zmm[16] = pzero(zmm[16]); zmm[17] = pzero(zmm[17]); zmm[18] = pzero(zmm[18]); zmm[19] = pzero(zmm[19]); zmm[20] = pzero(zmm[20]); zmm[21] = pzero(zmm[21]); zmm[22] = pzero(zmm[22]); zmm[23] = pzero(zmm[23]); zmm[24] = pzero(zmm[24]); zmm[25] = pzero(zmm[25]); zmm[26] = pzero(zmm[26]); zmm[27] = pzero(zmm[27]); zmm[28] = pzero(zmm[28]); zmm[29] = pzero(zmm[29]); zmm[30] = pzero(zmm[30]); zmm[31] = pzero(zmm[31]); } }; // Compute kernel with max unroll support of: // Single precision: // max_a_unroll: 48, 32, 16, 8, 4, 2, 1 // max_b_unroll: 8, 4, 2, 1 // Double precision: // max_a_unroll: 24, 16, 8, 4, 2, 1 // max_b_unroll: 8, 4, 2, 1 template EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b, Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1, Index a_off = 0, Index b_off = 0) { if (a_stride == -1) a_stride = k; if (b_stride == -1) b_stride = k; gemm_class g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off, b_off); g.template compute_kern(); } // Template specializations of GEBP kernels with nr = 8. #if EIGEN_USE_AVX512_GEMM_KERNELS template class gebp_traits : public gebp_traits { using Base = gebp_traits; public: enum { nr = Base::Vectorizable ? 8 : 4 }; }; template class gebp_traits : public gebp_traits { using Base = gebp_traits; public: enum { nr = Base::Vectorizable ? 
8 : 4 }; }; template struct gemm_pack_rhs { typedef typename packet_traits::type Packet; typedef typename DataMapper::LinearMapper LinearMapper; enum { PacketSize = packet_traits::size }; EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0, Index offset = 0); }; template EIGEN_DONT_INLINE void gemm_pack_rhs::operator()( Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) { constexpr int nr = 8; EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR"); EIGEN_UNUSED_VARIABLE(stride); EIGEN_UNUSED_VARIABLE(offset); eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride)); conj_if::IsComplex && Conjugate> cj; Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0; Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0; Index count = 0; const Index peeled_k = (depth / PacketSize) * PacketSize; if (nr >= 8) { for (Index j2 = 0; j2 < packet_cols8; j2 += 8) { // skip what we have before if (PanelMode) count += 8 * offset; const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0); const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1); const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2); const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3); const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4); const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5); const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6); const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7); Index k = 0; if ((PacketSize % 8) == 0) // TODO enable vectorized transposition for PacketSize==4 { for (; k < peeled_k; k += PacketSize) { PacketBlock kernel; kernel.packet[0] = dm0.template loadPacket(k); kernel.packet[1] = dm1.template loadPacket(k); kernel.packet[2] = dm2.template loadPacket(k); kernel.packet[3] = dm3.template loadPacket(k); kernel.packet[4] = dm4.template loadPacket(k); kernel.packet[5] = dm5.template loadPacket(k); kernel.packet[6] = dm6.template loadPacket(k); kernel.packet[7] = dm7.template loadPacket(k); ptranspose(kernel); pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0])); pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize])); pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize])); pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize])); pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize])); pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize])); pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize])); pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize])); count += 8 * PacketSize; } } for (; k < depth; k++) { blockB[count + 0] = cj(dm0(k)); blockB[count + 1] = cj(dm1(k)); blockB[count + 2] = cj(dm2(k)); blockB[count + 3] = cj(dm3(k)); blockB[count + 4] = cj(dm4(k)); blockB[count + 5] = cj(dm5(k)); blockB[count + 6] = cj(dm6(k)); blockB[count + 7] = cj(dm7(k)); count += 8; } // skip what we have after if (PanelMode) count += 8 * (stride - offset - depth); } } if (nr >= 4) { for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) { // skip what we have before if (PanelMode) count += 4 * offset; const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0); const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1); const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2); const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3); Index k = 0; if ((PacketSize % 4) == 0) 
// TODO enable vectorized transposition for PacketSize==2 ?? { for (; k < peeled_k; k += PacketSize) { PacketBlock kernel; kernel.packet[0] = dm0.template loadPacket(k); kernel.packet[1 % PacketSize] = dm1.template loadPacket(k); kernel.packet[2 % PacketSize] = dm2.template loadPacket(k); kernel.packet[3 % PacketSize] = dm3.template loadPacket(k); ptranspose(kernel); pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0])); pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize])); pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize])); pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize])); count += 4 * PacketSize; } } for (; k < depth; k++) { blockB[count + 0] = cj(dm0(k)); blockB[count + 1] = cj(dm1(k)); blockB[count + 2] = cj(dm2(k)); blockB[count + 3] = cj(dm3(k)); count += 4; } // skip what we have after if (PanelMode) count += 4 * (stride - offset - depth); } } // copy the remaining columns one at a time (nr==1) for (Index j2 = packet_cols4; j2 < cols; ++j2) { if (PanelMode) count += offset; const LinearMapper dm0 = rhs.getLinearMapper(0, j2); for (Index k = 0; k < depth; k++) { blockB[count] = cj(dm0(k)); count += 1; } if (PanelMode) count += (stride - offset - depth); } } template struct gemm_pack_rhs { typedef typename packet_traits::type Packet; typedef typename unpacket_traits::half HalfPacket; typedef typename unpacket_traits::half>::half QuarterPacket; typedef typename DataMapper::LinearMapper LinearMapper; enum { PacketSize = packet_traits::size, HalfPacketSize = unpacket_traits::size, QuarterPacketSize = unpacket_traits::size }; EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0, Index offset = 0) { constexpr int nr = 8; EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR"); EIGEN_UNUSED_VARIABLE(stride); EIGEN_UNUSED_VARIABLE(offset); eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride)); const bool HasHalf = (int)HalfPacketSize < (int)PacketSize; const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize; conj_if::IsComplex && Conjugate> cj; Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0; Index packet_cols4 = nr >= 4 ? 
(cols / 4) * 4 : 0; Index count = 0; if (nr >= 8) { for (Index j2 = 0; j2 < packet_cols8; j2 += 8) { // skip what we have before if (PanelMode) count += 8 * offset; for (Index k = 0; k < depth; k++) { if (PacketSize == 8) { // Packet A = ploadu(&rhs.data()[k*rhs.stride() + j2]); Packet A = rhs.template loadPacket(k, j2); pstoreu(blockB + count, cj.pconj(A)); } else if (HasHalf && HalfPacketSize == 8) { HalfPacket A = rhs.template loadPacket(k, j2); pstoreu(blockB + count, cj.pconj(A)); } else if (HasQuarter && QuarterPacketSize == 8) { QuarterPacket A = rhs.template loadPacket(k, j2); pstoreu(blockB + count, cj.pconj(A)); } else if (PacketSize == 4) { // Packet A = ploadu(&rhs.data()[k*rhs.stride() + j2]); // Packet B = ploadu(&rhs.data()[k*rhs.stride() + j2 + PacketSize]); Packet A = rhs.template loadPacket(k, j2); Packet B = rhs.template loadPacket(k, j2 + PacketSize); pstoreu(blockB + count, cj.pconj(A)); pstoreu(blockB + count + PacketSize, cj.pconj(B)); } else { // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2]; const LinearMapper dm0 = rhs.getLinearMapper(k, j2); blockB[count + 0] = cj(dm0(0)); blockB[count + 1] = cj(dm0(1)); blockB[count + 2] = cj(dm0(2)); blockB[count + 3] = cj(dm0(3)); blockB[count + 4] = cj(dm0(4)); blockB[count + 5] = cj(dm0(5)); blockB[count + 6] = cj(dm0(6)); blockB[count + 7] = cj(dm0(7)); } count += 8; } // skip what we have after if (PanelMode) count += 8 * (stride - offset - depth); } } if (nr >= 4) { for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) { // skip what we have before if (PanelMode) count += 4 * offset; for (Index k = 0; k < depth; k++) { if (PacketSize == 4) { Packet A = rhs.template loadPacket(k, j2); pstoreu(blockB + count, cj.pconj(A)); count += PacketSize; } else if (HasHalf && HalfPacketSize == 4) { HalfPacket A = rhs.template loadPacket(k, j2); pstoreu(blockB + count, cj.pconj(A)); count += HalfPacketSize; } else if (HasQuarter && QuarterPacketSize == 4) { QuarterPacket A = rhs.template loadPacket(k, j2); pstoreu(blockB + count, cj.pconj(A)); count += QuarterPacketSize; } else { const LinearMapper dm0 = rhs.getLinearMapper(k, j2); blockB[count + 0] = cj(dm0(0)); blockB[count + 1] = cj(dm0(1)); blockB[count + 2] = cj(dm0(2)); blockB[count + 3] = cj(dm0(3)); count += 4; } } // skip what we have after if (PanelMode) count += 4 * (stride - offset - depth); } } // copy the remaining columns one at a time (nr==1) for (Index j2 = packet_cols4; j2 < cols; ++j2) { if (PanelMode) count += offset; for (Index k = 0; k < depth; k++) { blockB[count] = cj(rhs(k, j2)); count += 1; } if (PanelMode) count += stride - offset - depth; } } }; template struct gebp_kernel { EIGEN_ALWAYS_INLINE void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0, Index offsetB = 0); }; template EIGEN_ALWAYS_INLINE void gebp_kernel::operator()( const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { if (res.incr() == 1) { if (alpha == 1) { gemm_kern_avx512(rows, cols, depth, &alpha, blockA, blockB, (Scalar *)res.data(), res.stride(), res.incr(), strideA, strideB, offsetA, offsetB); } else { gemm_kern_avx512(rows, cols, depth, &alpha, blockA, blockB, (Scalar *)res.data(), res.stride(), res.incr(), strideA, strideB, offsetA, offsetB); } } else { if (alpha == 1) { gemm_kern_avx512(rows, cols, depth, &alpha, 
blockA, blockB, (Scalar *)res.data(), res.stride(), res.incr(),
                       strideA, strideB, offsetA, offsetB);
    } else {
      gemm_kern_avx512(rows, cols, depth, &alpha, blockA, blockB, (Scalar *)res.data(), res.stride(), res.incr(),
                       strideA, strideB, offsetA, offsetB);
    }
  }
}

#endif  // EIGEN_USE_AVX512_GEMM_KERNELS

}  // namespace internal
}  // namespace Eigen

#undef SECOND_FETCH

#endif  // EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
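
// Usage sketch (illustrative only, not part of the kernel): the gebp_traits/gebp_kernel
// specializations above are picked up automatically by Eigen's GEBP product path when the
// build targets AVX512 and EIGEN_USE_AVX512_GEMM_KERNELS is enabled (the default), so a
// plain matrix product is enough to exercise gemm_kern_avx512. The snippet below is a
// minimal standalone example under those assumptions (e.g. compiled with -mavx512f
// -mavx512dq); it is not Eigen test code.
//
//   #include <Eigen/Dense>
//   #include <iostream>
//
//   int main() {
//     Eigen::MatrixXf A = Eigen::MatrixXf::Random(512, 512);
//     Eigen::MatrixXf B = Eigen::MatrixXf::Random(512, 512);
//     // The product dispatches through gebp_kernel, which forwards to gemm_kern_avx512.
//     Eigen::MatrixXf C = A * B;
//     std::cout << C.norm() << "\n";
//     return 0;
//   }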