diff --git a/Eigen/Core b/Eigen/Core index 7bbdee3cf..63b9850c9 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -356,6 +356,10 @@ using std::ptrdiff_t; #include "src/Core/arch/NEON/GeneralBlockPanelKernel.h" #endif +#if defined(EIGEN_VECTORIZE_AVX512) + #include "src/Core/arch/AVX512/GemmKernel.h" +#endif + #include "src/Core/BooleanRedux.h" #include "src/Core/Select.h" #include "src/Core/VectorwiseOp.h" diff --git a/Eigen/src/Core/arch/AVX512/GemmKernel.h b/Eigen/src/Core/arch/AVX512/GemmKernel.h new file mode 100644 index 000000000..ee4beb91a --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/GemmKernel.h @@ -0,0 +1,1182 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2022 Intel Corporation +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef GEMM_KERNEL_H +#define GEMM_KERNEL_H + +#include +#include +#include + +#include "../../InternalHeaderCheck.h" + +#define SECOND_FETCH (32) +#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS) +// Use less registers to load A elements to workaround compiler spills. Loose a +// bit of performance (less than ~2%). +#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS +#endif + + +namespace Eigen { +namespace internal { + +template +class gemm_class +{ + using vec = typename packet_traits::type; + using vec_ymm = typename unpacket_traits::half; + using vec_xmm = typename unpacket_traits::half; + using umask_t = typename unpacket_traits::mask_t; + + static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float); + static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double); + +#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS + static constexpr bool use_less_a_regs = !is_unit_inc; +#else + static constexpr bool use_less_a_regs = true; +#endif +#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS + static constexpr bool use_less_b_regs = !is_unit_inc; +#else + static constexpr bool use_less_b_regs = true; +#endif + + static constexpr int a_regs[] = {0, 1, 2, + use_less_a_regs ? 0 : 3, + use_less_a_regs ? 1 : 4, + use_less_a_regs ? 2 : 5 + }; + static constexpr int b_regs[] = {6, + use_less_b_regs ? 6 : 7 + }; + static constexpr int c_regs[] = { + 8 , 16, 24, + 9 , 17, 25, + 10, 18, 26, + 11, 19, 27, + 12, 20, 28, + 13, 21, 29, + 14, 22, 30, + 15, 23, 31, + }; + + static constexpr int alpha_load_reg = 0; + static constexpr int c_load_regs[] = {1, 2, 6}; + + static constexpr int a_shift = 128; + static constexpr int b_shift = 128; + + static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8; + static constexpr int a_prefetch_size = nelems_in_cache_line * 2; + static constexpr int b_prefetch_size = nelems_in_cache_line * 8; + + vec zmm[32]; + umask_t mask; + + // gemm arguments. + Index m; + const Index n, k, ldc; + const Index inc; + const Scalar *alpha; + + const Scalar *a, *b; + Scalar *c; + + const bool is_alpha1; + const bool is_beta0; + + const Index a_stride, b_stride; + const Index a_off, b_off; + + static EIGEN_ALWAYS_INLINE constexpr int div_up(int a, int b) { + return a == 0 ? 
0 : (a - 1) / b + 1; + } + + EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) + { + _mm_prefetch((char *) (a_prefetch_size + a_addr - a_shift), _MM_HINT_T0); + } + + EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) + { + _mm_prefetch((char *) (b_prefetch_size + b_addr - b_shift), _MM_HINT_T0); + } + + EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) + { + _mm_prefetch((char *) (x_addr - a_shift), _MM_HINT_T2); + } + + EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) + { +#if defined(__PRFCHW__) && __PRFCHW__ == 1 + _m_prefetchw((void *) c_addr); +#else + _mm_prefetch((char *) c_addr, _MM_HINT_T0); +#endif + } + + template + EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) + { + switch (nelems * sizeof(*a_addr) * 8) { + default: + case 512 * 3: a_reg = ploadu(a_addr); break; + case 512 * 2: a_reg = ploadu(a_addr); break; + case 512 * 1: a_reg = ploadu(a_addr); break; + case 256 * 1: a_reg = preinterpret(_mm512_broadcast_f64x4(ploadu(reinterpret_cast(a_addr)))); break; + case 128 * 1: a_reg = preinterpret(_mm512_broadcast_f32x4(ploadu(reinterpret_cast(a_addr)))); break; + case 64 * 1: a_reg = preinterpret(pload1(reinterpret_cast(a_addr))); break; + case 32 * 1: a_reg = pload1(a_addr); break; + } + } + + EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) + { + b_reg = pload1(b_addr); + } + + template + EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) + { + if (is_unit_inc) { + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: pstoreu(mem, src); break; + case 512 * 2: pstoreu(mem, src); break; + case 512 * 1: pstoreu(mem, src); break; + case 256 * 1: pstoreu(mem, preinterpret(src)); break; + case 128 * 1: pstoreu(mem, preinterpret(src)); break; + case 64 * 1: pstorel(mem, preinterpret(src)); break; + case 32 * 1: pstores(mem, preinterpret(src)); break; + } + } else { + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: pscatter(mem, src, inc); break; + case 512 * 2: pscatter(mem, src, inc); break; + case 512 * 1: pscatter(mem, src, inc); break; + case 256 * 1: pscatter(mem, src, inc, mask); break; + case 128 * 1: pscatter(mem, src, inc, mask); break; + case 64 * 1: pscatter(mem, src, inc, mask); break; + case 32 * 1: pscatter(mem, src, inc, mask); break; + } + } + } + + template + EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec ®) + { + if (is_unit_inc) { + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: dst = padd(src, ploadu(mem)); break; + case 512 * 2: dst = padd(src, ploadu(mem)); break; + case 512 * 1: dst = padd(src, ploadu(mem)); break; + case 256 * 1: dst = preinterpret(padd(preinterpret(src), ploadu(mem))); break; + case 128 * 1: dst = preinterpret(padd(preinterpret(src), ploadu(mem))); break; + case 64 * 1: dst = preinterpret(padd(preinterpret(src), ploadl(mem))); break; + case 32 * 1: dst = preinterpret(padds(preinterpret(src), ploads(mem))); break; + } + } else { + // Zero out scratch register + reg = pzero(reg); + + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: reg = pgather(mem, inc); dst = padd(src, reg); break; + case 512 * 2: reg = pgather(mem, inc); dst = padd(src, reg); break; + case 512 * 1: reg = pgather(mem, inc); dst = padd(src, reg); break; + case 256 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); break; + case 128 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); break; + case 64 * 1: if 
(is_f32) { + reg = pgather(reg, mem, inc, mask); + dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); + } else { + dst = preinterpret(padd(preinterpret(src), ploadl(mem))); + } + break; + case 32 * 1: dst = preinterpret(padds(preinterpret(src), ploads(mem))); break; + } + } + } + + EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) { + dst = pmadd(src1, src2, dst); + +#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0) + // Workaround register spills for gcc and clang + __asm__ ("#" : [dst] "+v" (dst) : [src1] "%v" (src1), [src2] "v" (src2)); +#endif + } + + template + EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec ®) + { + if (is_unit_inc) { + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: dst = pmadd(scale, src, ploadu(mem)); break; + case 512 * 2: dst = pmadd(scale, src, ploadu(mem)); break; + case 512 * 1: dst = pmadd(scale, src, ploadu(mem)); break; + case 256 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); break; + case 128 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); break; + case 64 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadl(mem))); break; + case 32 * 1: dst = preinterpret(pmadds(preinterpret(scale), preinterpret(src), ploads(mem))); break; + } + } else { + // Zero out scratch register + reg = pzero(reg); + + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: reg = pgather(mem, inc); dst = pmadd(scale, src, reg); break; + case 512 * 2: reg = pgather(mem, inc); dst = pmadd(scale, src, reg); break; + case 512 * 1: reg = pgather(mem, inc); dst = pmadd(scale, src, reg); break; + case 256 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); break; + case 128 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); break; + case 64 * 1: if (is_f32) { + reg = pgather(reg, mem, inc, mask); + dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); + } else { + dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadl(mem))); + } + break; + case 32 * 1: dst = preinterpret(pmadds(preinterpret(scale), preinterpret(src), ploads(mem))); break; + } + } + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> + a_loads(const Scalar *ao) + { + EIGEN_UNUSED_VARIABLE(ao); + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> + a_loads(const Scalar *ao) + { + if (j < endX) { + if (i < endY) { + auto &a_reg = zmm[a_regs[i + (j % 2) * 3]]; + const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift; + a_load(a_reg, a_addr); + + a_loads(ao); + } else { + a_loads(ao); + } + } + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> + prefetch_cs(const Scalar *co1, const Scalar *co2) + { + EIGEN_UNUSED_VARIABLE(co1); + EIGEN_UNUSED_VARIABLE(co2); + } + + /* C prefetch loop structure. + * for (int un = 0; un < 8; un++) { + * if (b_unroll >= un + 1) { + * if (un == 4) co2 = co1 + 4 * ldc; + * + * for (int i = 0; i < um_vecs; i++) { + * Scalar *co = (un + 1 <= 4) ? 
co1 : co2; + * auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; + * prefetch_c(co + co_off); + * } + * } + * } + */ + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> + prefetch_cs(Scalar *&co1, Scalar *&co2) + { + if (un < max_b_unroll) { + + if (b_unroll >= un + 1) { + if (un == 4 && i == 0) co2 = co1 + 4 * ldc; + + if (i < um_vecs) { + Scalar *co = (un + 1 <= 4) ? co1 : co2; + auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; + prefetch_c(co + co_off); + + prefetch_cs(co1, co2); + } else { + prefetch_cs(co1, co2); + } + + } else { + prefetch_cs(co1, co2); + } + } + } + + // load_c + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> + scale_load_c(const Scalar *cox, vec &alpha_reg) + { + EIGEN_UNUSED_VARIABLE(cox); + EIGEN_UNUSED_VARIABLE(alpha_reg); + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> + scale_load_c(const Scalar *cox, vec &alpha_reg) + { + if (i < um_vecs) { + auto &c_reg = zmm[c_regs[i + idx * 3]]; + auto &c_load_reg = zmm[c_load_regs[i % 3]]; + auto c_mem = cox; + if (is_unit_inc) + c_mem += i * nelems_in_cache_line; + else + c_mem += i * nelems_in_cache_line * inc; + + + if (!is_beta0 && is_alpha1) + vaddm(c_reg, c_mem, c_reg, c_load_reg); + else if (!is_beta0 && !is_alpha1) + vfmaddm(c_reg, c_mem, c_reg, alpha_reg, c_load_reg); + else if (is_beta0 && !is_alpha1) + c_reg = pmul(alpha_reg, c_reg); + + scale_load_c(cox, alpha_reg); + } + } + + // store_c + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> + write_c(Scalar *cox) + { + EIGEN_UNUSED_VARIABLE(cox); + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> + write_c(Scalar *cox) + { + if (i < um_vecs) { + auto &c_reg = zmm[c_regs[i + idx * 3]]; + auto c_mem = cox; + if (is_unit_inc) + c_mem += i * nelems_in_cache_line; + else + c_mem += i * nelems_in_cache_line * inc; + + c_store(c_mem, c_reg); + c_reg = pzero(c_reg); + + write_c(cox); + } + } + + /* C update loop structure. + * co2 = co1 + ldc; + * + * auto &alpha_reg = zmm[alpha_load_reg]; + * if (!is_alpha1) alpha_reg = pload1(alpha); + * + * int idx = 0; + * for (pow = 1; pow <= 8; pow <<= 1) { + * + * if (b_unroll >= pow) { + * for (count = 1; count < (pow + 1) / 2 + 1; count++) { + * if (pow >= 4) co2 += ldc; + * + * const Scalar *cox = (idx == 0) ? co1 : co2; + * + * const int um_vecs = div_up(a_unroll, nelems_in_cache_line); + * scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); + * write_c<0, um_vecs, idx, a_unroll>(cox); + * + * idx++; + * } + * } + * } + * + * if (b_unroll == 1) + * co1 += ldc; + * else + * co1 = co2 + ldc; + */ + + template + EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) + { + if (pow >= 4) cox += ldc; + + const int um_vecs = div_up(a_unroll, nelems_in_cache_line); + auto &alpha_reg = zmm[alpha_load_reg]; + + scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); + write_c<0, um_vecs, idx, a_unroll>(cox); + } + + template + EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) + { + constexpr int idx = pow / 2; + Scalar *&cox = idx == 0 ? 
co1 : co2; + + constexpr int max_count = (pow + 1) / 2; + static_assert(max_count <= 4, "Unsupported max_count."); + + if (1 <= max_count) c_update_1count(cox); + if (2 <= max_count) c_update_1count(cox); + if (3 <= max_count) c_update_1count(cox); + if (4 <= max_count) c_update_1count(cox); + } + + template + EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) + { + auto &alpha_reg = zmm[alpha_load_reg]; + + co2 = co1 + ldc; + if (!is_alpha1) alpha_reg = pload1(alpha); + if (!is_unit_inc && a_unroll < nelems_in_cache_line) + mask = (umask_t)(1 << a_unroll) - 1; + + static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll"); + + if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2); + if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2); + if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2); + if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2); + + if (b_unroll == 1) + co1 += ldc; + else + co1 = co2 + ldc; + } + + // compute + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> + compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg) + { + EIGEN_UNUSED_VARIABLE(ao); + EIGEN_UNUSED_VARIABLE(bo); + EIGEN_UNUSED_VARIABLE(fetchA_idx); + EIGEN_UNUSED_VARIABLE(fetchB_idx); + EIGEN_UNUSED_VARIABLE(b_reg); + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> + compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg) + { + if (um < um_vecs) { + auto &c_reg = zmm[c_regs[um + idx * 3]]; + auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]]; + + vfmadd(c_reg, a_reg, b_reg); + + if (!fetch_x && um == 0 && (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) || (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) { + prefetch_a(ao + nelems_in_cache_line * fetchA_idx); + fetchA_idx++; + } + + if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) { + prefetch_b(bo + nelems_in_cache_line * fetchB_idx); + fetchB_idx++; + } + + compute(ao, bo, fetchA_idx, fetchB_idx, b_reg); + } + } + + // load_a + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> + load_a(const Scalar *ao) + { + EIGEN_UNUSED_VARIABLE(ao); + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> + load_a(const Scalar *ao) + { + if (um < um_vecs) { + auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]]; + const Scalar *a_addr = ao + + nelems * (1 + !ktail * !use_less_a_regs + uk) + + nelems_in_cache_line * um - a_shift; + a_load(a_reg, a_addr); + + load_a(ao); + } + } + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> + innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) + { + EIGEN_UNUSED_VARIABLE(aa); + EIGEN_UNUSED_VARIABLE(ao); + EIGEN_UNUSED_VARIABLE(bo); + EIGEN_UNUSED_VARIABLE(co2); + EIGEN_UNUSED_VARIABLE(fetchA_idx); + EIGEN_UNUSED_VARIABLE(fetchB_idx); + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> + innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) + { + const int idx = (pow / 2) + count; + + if (count < (pow + 1) / 2) { + auto &b_reg = zmm[b_regs[idx % 2]]; + + if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); + if (fetch_x && uk == 3 && idx == 4) aa += 8; + + if (b_unroll >= pow) { + + compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); + 
+ const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + + (b_unroll > 1) * !use_less_b_regs - b_shift; + b_load(b_reg, b_addr); + } + + // Go to the next count. + innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + + } else { + // Maybe prefetch C data after count-loop. + if (pow == 2 && c_fetch) { + if (uk % 3 == 0 && uk > 0) { + co2 += ldc; + } else { + prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); + } + } + } + } + + template + EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) + { + const int um_vecs = div_up(a_unroll, nelems_in_cache_line); + + if (max_b_unroll >= 1) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + if (max_b_unroll >= 2) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + if (max_b_unroll >= 4) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + if (max_b_unroll >= 8) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + + // Load A after pow-loop. + load_a<0, um_vecs, uk, a_unroll, ktail>(ao); + } + + /* Inner kernel loop structure. + * for (int uk = 0; uk < kfactor; uk++) { + * int idx = 0; + * + * for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) { + * for (int count = 0; count < (pow + 1) / 2; count++) { + * auto &b_reg = zmm[b_regs[idx % 2]]; + * + * if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); + * if (fetch_x && uk == 3 && idx == 4) aa += 8; + * + * if (b_unroll >= pow) { + * compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); + * + * const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ; + * b_load(b_reg, b_addr); + * } + * idx++; + * } + * + * Maybe prefetch C data. + * if (pow == 2 && c_fetch) { + * if (uk % 3 == 0 && uk > 0) { + * co2 += ldc; + * } else { + * prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); + * } + * } + * } + * + * Load A. + * load_a<0, um_vecs, uk, ktail, a_unroll>(ao); + * } + * + * Advance A/B pointers after uk-loop. + * ao += a_unroll * kfactor; + * bo += b_unroll * kfactor; + */ + + template + EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) + { + int fetchA_idx = 0; + int fetchB_idx = 0; + + const bool fetch_x = k_factor == max_k_factor; + const bool ktail = k_factor == 1; + + static_assert(k_factor <= 4 && k_factor > 0, + "innerkernel maximum k_factor supported is 4"); + + if (k_factor > 0) innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + if (k_factor > 1) innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + if (k_factor > 2) innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + if (k_factor > 3) innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + + // Advance A/B pointers after uk-loop. 
+ ao += a_unroll * k_factor; + bo += b_unroll * k_factor; + } + + + template + EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) + { + const int um_vecs = div_up(a_unroll, nelems_in_cache_line); + if (!use_less_a_regs) + a_loads<0, 2, 0, um_vecs, a_unroll>(ao); + else + a_loads<0, 1, 0, um_vecs, a_unroll>(ao); + + b_load(zmm[b_regs[0]], bo - b_shift + 0); + if (!use_less_b_regs) + b_load(zmm[b_regs[1]], bo - b_shift + 1); + +#ifndef SECOND_FETCH + prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2); +#endif // SECOND_FETCH + + // Unrolling k-loop by a factor of 4. + const int max_k_factor = 4; + Index loop_count = k / max_k_factor; + + if (loop_count > 0) { +#ifdef SECOND_FETCH + loop_count -= SECOND_FETCH; +#endif + while (loop_count > 0) { + innerkernel(aa, ao, bo, co2); + loop_count--; + } +#ifdef SECOND_FETCH + co2 = co1 + nelems_in_cache_line - 1; + + loop_count += b_unroll; + while (loop_count > 0) { + innerkernel(aa, ao, bo, co2); + loop_count--; + } + + loop_count += SECOND_FETCH - b_unroll; + while (loop_count > 0) { + innerkernel(aa, ao, bo, co2); + loop_count--; + } +#endif + } + + // k-loop remainder handling. + loop_count = k % max_k_factor; + while (loop_count > 0) { + innerkernel(aa, ao, bo, co2); + loop_count--; + } + + // Update C matrix. + c_update(co1, co2); + } + + template + EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) + { + // Set A matrix pointer. + ao = a + a_off * a_unroll; + + // Set B matrix pointer if needed. + bo += b_unroll * b_off; + + kloop(aa, ao, bo, co1, co2); + + // Advance B matrix pointer if needed. + bo += b_unroll * (b_stride - k - b_off); + + // Advance prefetch A pointer. + aa += 16; + } + + template + EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) + { + // Set prefetch A pointers. + const Scalar *aa = a + a_unroll * a_stride; + + // Set C matrix pointers. + co1 = c; + if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc; + if (is_unit_inc) + c += a_unroll; + else + c += a_unroll * inc; + + // Set B matrix pointer. + bo = b; + + // Main n-loop. + for (Index i = n / max_b_unroll; i > 0; i--) + nloop(aa, ao, bo, co1, co2); + + // n-remainders. + if (n & 4 && max_b_unroll > 4) nloop(aa, ao, bo, co1, co2); +#if 0 + if (n & 2 && max_b_unroll > 2) nloop(aa, ao, bo, co1, co2); + if (n & 1 && max_b_unroll > 1) nloop(aa, ao, bo, co1, co2); +#else + // Copy kernels don't support tails of n = 2 for single/double precision. + // Loop over ones. + int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0); + while (n_rem > 0) {nloop(aa, ao, bo, co1, co2); n_rem--;} +#endif + + // Advance A matrix pointer. + a = ao + a_unroll * (a_stride - k - a_off); + } + +public: + // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll. + template + EIGEN_ALWAYS_INLINE void compute_kern() + { + a -= -a_shift; + b -= -b_shift; + + const Scalar *ao = nullptr; + const Scalar *bo = nullptr; + Scalar *co1 = nullptr; + Scalar *co2 = nullptr; + + // Main m-loop. + for (; m >= max_a_unroll; m -= max_a_unroll) + mloop(ao, bo, co1, co2); + + // m-remainders. 
+ if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 8 && max_a_unroll > 8) mloop< 8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 4 && max_a_unroll > 4) mloop< 4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 2 && max_a_unroll > 2 && is_f64) mloop< 2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 1 && max_a_unroll > 1 && is_f64) mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + + // Copy kernels don't support tails of m = 2 for single precision. + // Loop over ones. + if (is_f32) { + int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0); + while (m_rem > 0) {mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); m_rem--;} + } + } + + gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, + const Scalar *alpha_, + const Scalar *a_, const Scalar *b_, Scalar *c_, + bool is_alpha1_, bool is_beta0_, + Index a_stride_, Index b_stride_, + Index a_off_, Index b_off_) + : m(m_) + , n(n_) + , k(k_) + , ldc(ldc_) + , inc(inc_) + , alpha(alpha_) + , a(a_) + , b(b_) + , c(c_) + , is_alpha1(is_alpha1_) + , is_beta0(is_beta0_) + , a_stride(a_stride_) + , b_stride(b_stride_) + , a_off(a_off_) + , b_off(b_off_) + { + // Zero out all accumulation registers. + zmm[8 ] = pzero(zmm[8 ]); + zmm[9 ] = pzero(zmm[9 ]); + zmm[10] = pzero(zmm[10]); + zmm[11] = pzero(zmm[11]); + zmm[12] = pzero(zmm[12]); + zmm[13] = pzero(zmm[13]); + zmm[14] = pzero(zmm[14]); + zmm[15] = pzero(zmm[15]); + zmm[16] = pzero(zmm[16]); + zmm[17] = pzero(zmm[17]); + zmm[18] = pzero(zmm[18]); + zmm[19] = pzero(zmm[19]); + zmm[20] = pzero(zmm[20]); + zmm[21] = pzero(zmm[21]); + zmm[22] = pzero(zmm[22]); + zmm[23] = pzero(zmm[23]); + zmm[24] = pzero(zmm[24]); + zmm[25] = pzero(zmm[25]); + zmm[26] = pzero(zmm[26]); + zmm[27] = pzero(zmm[27]); + zmm[28] = pzero(zmm[28]); + zmm[29] = pzero(zmm[29]); + zmm[30] = pzero(zmm[30]); + zmm[31] = pzero(zmm[31]); + } +}; + +// Compute kernel with max unroll support of: +// Single precision: +// max_a_unroll: 48, 32, 16, 8, 4, 2, 1 +// max_b_unroll: 8, 4, 2, 1 +// Double precision: +// max_a_unroll: 24, 16, 8, 4, 2, 1 +// max_b_unroll: 8, 4, 2, 1 +template +EIGEN_DONT_INLINE void gemm_kern_avx512( + Index m, Index n, Index k, + Scalar *alpha, const Scalar *a, const Scalar *b, Scalar *c, + Index ldc, Index inc = 1, + Index a_stride = -1, Index b_stride = -1, + Index a_off = 0, Index b_off = 0) +{ + if (a_stride == -1) a_stride = k; + if (b_stride == -1) b_stride = k; + + gemm_class g(m, n, k, ldc, inc, alpha, + a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off, b_off); + g.template compute_kern(); +} + +template +class gebp_traits : + public gebp_traits { + using Base = gebp_traits; + +public: + enum {nr = Base::Vectorizable ? 8 : 4}; +}; + +template +class gebp_traits : + public gebp_traits { + using Base = gebp_traits; + +public: + enum {nr = Base::Vectorizable ? 
8 : 4}; +}; + +template +struct gemm_pack_rhs +{ + typedef typename packet_traits::type Packet; + typedef typename DataMapper::LinearMapper LinearMapper; + enum { PacketSize = packet_traits::size }; + EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +EIGEN_DONT_INLINE void gemm_pack_rhs + ::operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + constexpr int nr = 8; + EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR"); + EIGEN_UNUSED_VARIABLE(stride); + EIGEN_UNUSED_VARIABLE(offset); + eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); + conj_if::IsComplex && Conjugate> cj; + Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0; + Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0; + Index count = 0; + const Index peeled_k = (depth/PacketSize)*PacketSize; + if(nr>=8) + { + for(Index j2=0; j2 kernel; + + kernel.packet[0] = dm0.template loadPacket(k); + kernel.packet[1] = dm1.template loadPacket(k); + kernel.packet[2] = dm2.template loadPacket(k); + kernel.packet[3] = dm3.template loadPacket(k); + kernel.packet[4] = dm4.template loadPacket(k); + kernel.packet[5] = dm5.template loadPacket(k); + kernel.packet[6] = dm6.template loadPacket(k); + kernel.packet[7] = dm7.template loadPacket(k); + + ptranspose(kernel); + + pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0])); + pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize])); + pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize])); + pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize])); + pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel.packet[4%PacketSize])); + pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel.packet[5%PacketSize])); + pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel.packet[6%PacketSize])); + pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel.packet[7%PacketSize])); + count+=8*PacketSize; + } + } + for(; k=4) + { + for(Index j2=packet_cols8; j2 kernel; + kernel.packet[0 ] = dm0.template loadPacket(k); + kernel.packet[1%PacketSize] = dm1.template loadPacket(k); + kernel.packet[2%PacketSize] = dm2.template loadPacket(k); + kernel.packet[3%PacketSize] = dm3.template loadPacket(k); + ptranspose(kernel); + pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0])); + pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize])); + pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize])); + pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize])); + count+=4*PacketSize; + } + } + for(; k +struct gemm_pack_rhs +{ + typedef typename packet_traits::type Packet; + typedef typename unpacket_traits::half HalfPacket; + typedef typename unpacket_traits::half>::half QuarterPacket; + typedef typename DataMapper::LinearMapper LinearMapper; + enum { PacketSize = packet_traits::size, + HalfPacketSize = unpacket_traits::size, + QuarterPacketSize = unpacket_traits::size}; + EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0) + { + constexpr int nr = 8; + EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR"); + EIGEN_UNUSED_VARIABLE(stride); + EIGEN_UNUSED_VARIABLE(offset); + eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); + const bool HasHalf = (int)HalfPacketSize < (int)PacketSize; + const bool 
HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize; + conj_if::IsComplex && Conjugate> cj; + Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0; + Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0; + Index count = 0; + + if(nr>=8) + { + for(Index j2=0; j2(&rhs.data()[k*rhs.stride() + j2]); + Packet A = rhs.template loadPacket(k, j2); + pstoreu(blockB+count, cj.pconj(A)); + } else if (HasHalf && HalfPacketSize==8) { + HalfPacket A = rhs.template loadPacket(k, j2); + pstoreu(blockB+count, cj.pconj(A)); + } else if (HasQuarter && QuarterPacketSize==8) { + QuarterPacket A = rhs.template loadPacket(k, j2); + pstoreu(blockB+count, cj.pconj(A)); + } else if (PacketSize==4) { + // Packet A = ploadu(&rhs.data()[k*rhs.stride() + j2]); + // Packet B = ploadu(&rhs.data()[k*rhs.stride() + j2 + PacketSize]); + Packet A = rhs.template loadPacket(k, j2); + Packet B = rhs.template loadPacket(k, j2 + PacketSize); + pstoreu(blockB+count, cj.pconj(A)); + pstoreu(blockB+count+PacketSize, cj.pconj(B)); + } else { + // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2]; + const LinearMapper dm0 = rhs.getLinearMapper(k, j2); + blockB[count+0] = cj(dm0(0)); + blockB[count+1] = cj(dm0(1)); + blockB[count+2] = cj(dm0(2)); + blockB[count+3] = cj(dm0(3)); + blockB[count+4] = cj(dm0(4)); + blockB[count+5] = cj(dm0(5)); + blockB[count+6] = cj(dm0(6)); + blockB[count+7] = cj(dm0(7)); + } + count += 8; + } + // skip what we have after + if(PanelMode) count += 8 * (stride-offset-depth); + } + } + + if(nr>=4) + { + for(Index j2=packet_cols8; j2(k, j2); + pstoreu(blockB+count, cj.pconj(A)); + count += PacketSize; + } else if (HasHalf && HalfPacketSize==4) { + HalfPacket A = rhs.template loadPacket(k, j2); + pstoreu(blockB+count, cj.pconj(A)); + count += HalfPacketSize; + } else if (HasQuarter && QuarterPacketSize==4) { + QuarterPacket A = rhs.template loadPacket(k, j2); + pstoreu(blockB+count, cj.pconj(A)); + count += QuarterPacketSize; + } else { + const LinearMapper dm0 = rhs.getLinearMapper(k, j2); + blockB[count+0] = cj(dm0(0)); + blockB[count+1] = cj(dm0(1)); + blockB[count+2] = cj(dm0(2)); + blockB[count+3] = cj(dm0(3)); + count += 4; + } + } + // skip what we have after + if(PanelMode) count += 4 * (stride-offset-depth); + } + } + // copy the remaining columns one at a time (nr==1) + for(Index j2=packet_cols4; j2 +struct gebp_kernel +{ + EIGEN_ALWAYS_INLINE + void operator()(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, + Index rows, Index depth, Index cols, Scalar alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +EIGEN_ALWAYS_INLINE +void gebp_kernel + ::operator()(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, + Index rows, Index depth, Index cols, Scalar alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + if (res.incr() == 1) { + if (alpha == 1) { + gemm_kern_avx512(rows, cols, depth, + &alpha, blockA, blockB, (Scalar *)res.data(), res.stride(), + res.incr(), strideA, strideB, offsetA, offsetB); + } else { + gemm_kern_avx512(rows, cols, depth, + &alpha, blockA, blockB, (Scalar *)res.data(), + res.stride(), res.incr(), strideA, strideB, offsetA, offsetB); + } + } else { + if (alpha == 1) { + gemm_kern_avx512(rows, cols, depth, + &alpha, blockA, blockB, (Scalar *)res.data(), res.stride(), + res.incr(), strideA, strideB, offsetA, offsetB); + } else { + gemm_kern_avx512(rows, cols, depth, + &alpha, blockA, blockB, (Scalar *)res.data(), res.stride(), + res.incr(), strideA, strideB, offsetA, offsetB); + } + } +} + 
+} // namespace Eigen +} // namespace internal + +#endif // GEMM_KERNEL_H diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 337001b13..aab066aed 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -927,6 +927,35 @@ EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet8d& from, uint8 EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from); } +template +EIGEN_DEVICE_FUNC inline Packet pgather(const Packet& src, const Scalar* from, + Index stride, typename unpacket_traits::mask_t umask); +template <> +EIGEN_DEVICE_FUNC inline Packet16f pgather(const Packet16f& src, + const float* from, + Index stride, + uint16_t umask) { + Packet16i stride_vector = _mm512_set1_epi32(convert_index(stride)); + Packet16i stride_multiplier = + _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier); + __mmask16 mask = static_cast<__mmask16>(umask); + + return _mm512_mask_i32gather_ps(src, mask, indices, from, 4); +} +template <> +EIGEN_DEVICE_FUNC inline Packet8d pgather(const Packet8d& src, + const double* from, + Index stride, + uint8_t umask) { + Packet8i stride_vector = _mm256_set1_epi32(convert_index(stride)); + Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier); + __mmask8 mask = static_cast<__mmask8>(umask); + + return _mm512_mask_i32gather_pd(src, mask, indices, from, 8); +} + template <> EIGEN_DEVICE_FUNC inline Packet16f pgather(const float* from, Index stride) { @@ -956,6 +985,33 @@ EIGEN_DEVICE_FUNC inline Packet16i pgather(const int* from, return _mm512_i32gather_epi32(indices, from, 4); } +template +EIGEN_DEVICE_FUNC inline void pscatter(Scalar* to, const Packet& from, + Index stride, typename unpacket_traits::mask_t umask); +template <> +EIGEN_DEVICE_FUNC inline void pscatter(float* to, + const Packet16f& from, + Index stride, + uint16_t umask) { + Packet16i stride_vector = _mm512_set1_epi32(convert_index(stride)); + Packet16i stride_multiplier = + _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier); + __mmask16 mask = static_cast<__mmask16>(umask); + _mm512_mask_i32scatter_ps(to, mask, indices, from, 4); +} +template <> +EIGEN_DEVICE_FUNC inline void pscatter(double* to, + const Packet8d& from, + Index stride, + uint8_t umask) { + Packet8i stride_vector = _mm256_set1_epi32(convert_index(stride)); + Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier); + __mmask8 mask = static_cast<__mmask8>(umask); + _mm512_mask_i32scatter_pd(to, mask, indices, from, 8); +} + template <> EIGEN_DEVICE_FUNC inline void pscatter(float* to, const Packet16f& from, @@ -1450,28 +1506,24 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6))); kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7))); kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7))); - - T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); - T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); - T4 = 
_mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); - T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); - T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); - T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); - T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); - T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); - T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); - T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); - T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); - T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); - T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); - T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); - T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); - T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); - kernel.packet[0] = T0; kernel.packet[1] = T1; - kernel.packet[2] = T2; kernel.packet[3] = T3; - kernel.packet[4] = T4; kernel.packet[5] = T5; - kernel.packet[6] = T6; kernel.packet[7] = T7; + T0 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0x44); + T1 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0xee); + T2 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0x44); + T3 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0xee); + T4 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0x44); + T5 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0xee); + T6 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0x44); + T7 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0xee); + + kernel.packet[0] = _mm512_shuffle_f32x4(T0, T2, 0x88); + kernel.packet[2] = _mm512_shuffle_f32x4(T0, T2, 0xdd); + kernel.packet[1] = _mm512_shuffle_f32x4(T4, T6, 0x88); + kernel.packet[3] = _mm512_shuffle_f32x4(T4, T6, 0xdd); + kernel.packet[4] = _mm512_shuffle_f32x4(T1, T3, 0x88); + kernel.packet[6] = _mm512_shuffle_f32x4(T1, T3, 0xdd); + kernel.packet[5] = _mm512_shuffle_f32x4(T5, T7, 0x88); + kernel.packet[7] = _mm512_shuffle_f32x4(T5, T7, 0xdd); } EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { diff --git a/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc index 22cb1c93d..03640c9b7 100644 --- a/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +++ b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc @@ -65,6 +65,57 @@ EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { return 0; } +template +EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock& kernel); + +template <> +EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock& kernel) { + __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]); + __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]); + __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2],kernel.packet[3]); + __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2],kernel.packet[3]); + __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4],kernel.packet[5]); + __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4],kernel.packet[5]); + __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]); + __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]); + + kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2))); + kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2))); + kernel.packet[2] = 
_mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3))); + kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3))); + kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6))); + kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6))); + kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7))); + kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7))); + + T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); + T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); + T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); + T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); + T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); + T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); + T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); + T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); + T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); + T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); + T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); + T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); + T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); + T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); + T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); + T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); + + kernel.packet[0] = T0; kernel.packet[1] = T1; + kernel.packet[2] = T2; kernel.packet[3] = T3; + kernel.packet[4] = T4; kernel.packet[5] = T5; + kernel.packet[6] = T6; kernel.packet[7] = T7; +} + +template <> +EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock& kernel) { + ptranspose(kernel); +} + /*** * Unrolls for tranposed C stores */ @@ -198,7 +249,7 @@ public: r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5]; r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6]; r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7]; - ptranspose(r); + trans8x8blocks(r); zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0]; zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1]; zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2]; diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h index 8baced1d6..d28cca214 100644 --- a/Eigen/src/Core/arch/AVX512/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -44,6 +44,34 @@ template<> EIGEN_STRONG_INLINE Packet8f preinterpret(const return _mm512_castps512_ps256(a); } +template<> EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet16f& a) { + return _mm512_castps512_ps128(a); +} + +template<> EIGEN_STRONG_INLINE Packet4d preinterpret(const Packet8d& a) { + return _mm512_castpd512_pd256(a); +} + +template<> EIGEN_STRONG_INLINE Packet2d preinterpret(const Packet8d& a) { + return _mm512_castpd512_pd128(a); +} + +template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet8f& a) { + return _mm512_castps256_ps512(a); +} + +template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet4f& a) { + return _mm512_castps128_ps512(a); +} + +template<> EIGEN_STRONG_INLINE Packet8d preinterpret(const Packet4d& a) { + return 
_mm512_castpd256_pd512(a); +} + +template<> EIGEN_STRONG_INLINE Packet8d preinterpret(const Packet2d& a) { + return _mm512_castpd128_pd512(a); +} + template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet16f& a) { return a; } diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 35490a602..e896040b5 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -285,6 +285,10 @@ template<> EIGEN_STRONG_INLINE Packet4i padd(const Packet4i& a, const template<> EIGEN_STRONG_INLINE Packet16b padd(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); } +template EIGEN_STRONG_INLINE Packet padds(const Packet& a, const Packet& b); +template<> EIGEN_STRONG_INLINE Packet4f padds(const Packet4f& a, const Packet4f& b) { return _mm_add_ss(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d padds(const Packet2d& a, const Packet2d& b) { return _mm_add_sd(a,b); } + template<> EIGEN_STRONG_INLINE Packet4f psub(const Packet4f& a, const Packet4f& b) { return _mm_sub_ps(a,b); } template<> EIGEN_STRONG_INLINE Packet2d psub(const Packet2d& a, const Packet2d& b) { return _mm_sub_pd(a,b); } template<> EIGEN_STRONG_INLINE Packet4i psub(const Packet4i& a, const Packet4i& b) { return _mm_sub_epi32(a,b); } @@ -370,6 +374,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f template<> EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmadd_pd(a,b,c); } template<> EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmsub_ps(a,b,c); } template<> EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmsub_pd(a,b,c); } + +template EIGEN_STRONG_INLINE Packet pmadds(const Packet& a, const Packet& b, const Packet& c); +template<> EIGEN_STRONG_INLINE Packet4f pmadds(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ss(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet2d pmadds(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_sd(a,b,c); } #endif #ifdef EIGEN_VECTORIZE_SSE4_1 @@ -746,6 +754,15 @@ template<> EIGEN_STRONG_INLINE Packet16b ploadu(const bool* from) return _mm_loadu_si128(reinterpret_cast(from)); } +// Load lower part of packet zero extending. 
+template EIGEN_STRONG_INLINE Packet ploadl(const typename unpacket_traits::type* from); +template<> EIGEN_STRONG_INLINE Packet4f ploadl(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_castpd_ps(_mm_load_sd(reinterpret_cast(from))); } +template<> EIGEN_STRONG_INLINE Packet2d ploadl(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); } + +// Load scalar +template EIGEN_STRONG_INLINE Packet ploads(const typename unpacket_traits::type* from); +template<> EIGEN_STRONG_INLINE Packet4f ploads(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_ss(from); } +template<> EIGEN_STRONG_INLINE Packet2d ploads(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); } template<> EIGEN_STRONG_INLINE Packet4f ploaddup(const float* from) { @@ -787,6 +804,14 @@ template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet4f& template<> EIGEN_STRONG_INLINE void pstoreu(int* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); } template<> EIGEN_STRONG_INLINE void pstoreu(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); } +template EIGEN_STRONG_INLINE void pstorel(Scalar* to, const Packet& from); +template<> EIGEN_STRONG_INLINE void pstorel(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pi(reinterpret_cast<__m64*>(to), from); } +template<> EIGEN_STRONG_INLINE void pstorel(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pd(to, from); } + +template EIGEN_STRONG_INLINE void pstores(Scalar* to, const Packet& from); +template<> EIGEN_STRONG_INLINE void pstores(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_ss(to, from); } +template<> EIGEN_STRONG_INLINE void pstores(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_sd(to, from); } + template<> EIGEN_DEVICE_FUNC inline Packet4f pgather(const float* from, Index stride) { return _mm_set_ps(from[3*stride], from[2*stride], from[1*stride], from[0*stride]); diff --git a/Eigen/src/Core/arch/SSE/TypeCasting.h b/Eigen/src/Core/arch/SSE/TypeCasting.h index c21d1acd9..a6346ea0e 100644 --- a/Eigen/src/Core/arch/SSE/TypeCasting.h +++ b/Eigen/src/Core/arch/SSE/TypeCasting.h @@ -71,6 +71,14 @@ template<> EIGEN_STRONG_INLINE Packet2d pcast(const Packet4f return _mm_cvtps_pd(a); } +template<> EIGEN_STRONG_INLINE Packet2d preinterpret(const Packet4f& a) { + return _mm_castps_pd(a); +} + +template<> EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet2d& a) { + return _mm_castpd_ps(a); +} + template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet4f& a) { return _mm_castps_si128(a); } diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 5262428b6..38bddacdb 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -287,7 +287,6 @@ class gemm_blocking_space LhsScalar; typedef std::conditional_t RhsScalar; - typedef gebp_traits Traits; enum { SizeA = ActualRows * MaxDepth, SizeB = ActualCols * MaxDepth @@ -336,7 +335,6 @@ class gemm_blocking_space LhsScalar; typedef std::conditional_t RhsScalar; - typedef gebp_traits Traits; Index m_sizeA; Index m_sizeB; diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index f45665e3b..e2eef19a9 100644 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -229,6 +229,7 @@ public: } 
   EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
+  EIGEN_DEVICE_FUNC const Index incr() const { return 1; }
   EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
 
   EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
@@ -402,6 +403,10 @@ public:
     storePacketBlock_helper spb;
     spb.store(this, i,j,block);
   }
+
+  EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
+  EIGEN_DEVICE_FUNC const Index incr() const { return m_incr.value(); }
+  EIGEN_DEVICE_FUNC Scalar* data() const { return m_data; }
 protected:
   Scalar* EIGEN_RESTRICT m_data;
   const Index m_stride;
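
Editor note: stripped of packing, prefetching and register blocking, the operation gemm_kern_avx512 performs is an alpha-scaled matrix product accumulated into C (or overwriting C when is_beta0 is set). The following is a minimal scalar reference sketch, not the packed-panel layout the kernel actually consumes; the column-major layout and leading-dimension parameters here are illustrative assumptions.

// Scalar reference sketch of the computation:
// C(i,j) = alpha * sum_p A(i,p) * B(p,j) + (is_beta0 ? 0 : C(i,j)).
// Plain column-major matrices with leading dimensions lda/ldb/ldc are assumed;
// the real kernel reads A and B from gemm-packed panels instead.
void gemm_reference(int m, int n, int k, float alpha,
                    const float *A, int lda, const float *B, int ldb,
                    float *C, int ldc, bool is_beta0) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = alpha * acc + (is_beta0 ? 0.f : C[i + j * ldc]);
    }
  }
}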
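Editor note: the masked pgather/pscatter overloads added in AVX512/PacketMath.h are what the kernel's non-unit-increment path uses for partial C tiles; lanes outside the mask are taken from the src packet on gather and left untouched in memory on scatter. A minimal usage sketch, assuming an AVX-512 build of Eigen and relying on template argument deduction for the overloads; the helper function and lane count are illustrative.

#include <Eigen/Core>
#include <cstdint>

// Sketch only: gather the first `valid` lanes of a strided float column into a
// Packet16f, then scatter them back; lanes outside the mask are not written.
void masked_copy_column(const float *src, float *dst, Eigen::Index stride, int valid) {
  using namespace Eigen::internal;
  uint16_t umask = static_cast<uint16_t>((1u << valid) - 1u);  // low `valid` lanes active
  Packet16f fill = pset1<Packet16f>(0.f);                      // value for inactive gather lanes
  Packet16f g = pgather(fill, src, stride, umask);
  pscatter(dst, g, stride, umask);
}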
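Editor note: the SSE additions (ploadl/ploads, pstorel/pstores, padds/pmadds) let the kernel handle 1- and 2-element column tails without masking. A small sketch of their semantics, assuming an SSE build; the values in comments follow the underlying intrinsics (_mm_load_ss, _mm_add_ss, _mm_load_sd, _mm_storel_pi).

#include <Eigen/Core>

void sse_tail_helpers_demo() {
  using namespace Eigen::internal;
  const float in[4] = {1.f, 2.f, 3.f, 4.f};
  float out[4] = {0.f, 0.f, 0.f, 0.f};

  Packet4f a = ploadu<Packet4f>(in);        // [1, 2, 3, 4]
  Packet4f s = ploads<Packet4f>(in + 1);    // scalar load: [2, 0, 0, 0]
  pstores(out, padds(a, s));                // lane 0 only: out[0] = 1 + 2 = 3

  Packet4f l = ploadl<Packet4f>(in + 2);    // low-half load: [3, 4, 0, 0]
  pstorel(out + 2, l);                      // low-half store: out[2] = 3, out[3] = 4
}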