From 9b9496ad980afe3626ba67ff871c15c3be8db620 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20S=C3=A1nchez?=
Date: Fri, 13 May 2022 18:50:33 +0000
Subject: [PATCH] Revert "Add AVX512 optimizations for matrix multiply"

This reverts commit 25db0b4a824ba9a092bbb514fbada51bf9d37a18
---
 Eigen/Core                                     |   1 -
 Eigen/src/Core/arch/AVX512/GemmKernel.h        | 973 ------------------
 Eigen/src/Core/arch/AVX512/PacketMath.h        |  66 +-
 Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc     |   2 +-
 Eigen/src/Core/arch/AVX512/TypeCasting.h       |  28 -
 .../Core/arch/NEON/GeneralBlockPanelKernel.h   |  18 +-
 Eigen/src/Core/arch/SSE/PacketMath.h           |  25 -
 Eigen/src/Core/arch/SSE/TypeCasting.h          |   8 -
 .../Core/products/GeneralBlockPanelKernel.h    | 239 ++---
 Eigen/src/Core/products/GeneralMatrixMatrix.h  |   6 +-
 .../products/GeneralMatrixMatrixTriangular.h   |   4 +-
 .../Core/products/GeneralMatrixMatrix_BLAS.h   |   2 +-
 .../Core/products/SelfadjointMatrixMatrix.h    |   4 +-
 .../Core/products/TriangularMatrixMatrix.h     |   4 +-
 .../Core/products/TriangularSolverMatrix.h     |   4 +-
 Eigen/src/Core/util/BlasUtil.h                 |   5 -
 unsupported/Eigen/MPRealSupport                |   4 +-
 17 files changed, 142 insertions(+), 1251 deletions(-)
 delete mode 100644 Eigen/src/Core/arch/AVX512/GemmKernel.h

diff --git a/Eigen/Core b/Eigen/Core
index 141460d15..7bbdee3cf 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -191,7 +191,6 @@ using std::ptrdiff_t;
 #include "src/Core/arch/AVX/MathFunctions.h"
 #include "src/Core/arch/AVX512/MathFunctions.h"
 #include "src/Core/arch/AVX512/TrsmKernel.h"
- #include "src/Core/arch/AVX512/GemmKernel.h"
 #elif defined EIGEN_VECTORIZE_AVX
 // Use AVX for floats and doubles, SSE for integers
 #include "src/Core/arch/SSE/PacketMath.h"
diff --git a/Eigen/src/Core/arch/AVX512/GemmKernel.h b/Eigen/src/Core/arch/AVX512/GemmKernel.h
deleted file mode 100644
index 4b2df9d63..000000000
--- a/Eigen/src/Core/arch/AVX512/GemmKernel.h
+++ /dev/null
@@ -1,973 +0,0 @@
-#ifndef GEMM_KERNEL_H
-#define GEMM_KERNEL_H
-
-#include
-#include
-#include
-
-#define SECOND_FETCH (32)
-
-#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
-// Use less registers to load A elements to workaround compiler spills. Loose a
-// bit of performance (less than ~2%).
-#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
-#endif
-
-namespace Eigen {
-namespace internal {
-
-static inline constexpr int div_up(int a, int b) {
-  return (a + b - 1) / b;
-}
-
-template
-class gemm_class
-{
-  using vec = typename std::conditional::value, Packet16f, Packet8d>::type;
-  using vec_ymm = typename std::conditional::value, Packet8f, Packet4d>::type;
-  using vec_xmm = typename std::conditional::value, Packet4f, Packet2d>::type;
-
-  static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
-  static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);
-
-#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
-  static constexpr int a_regs[] = {0, 1, 2, 3, 4, 5};
-#else
-  static constexpr int a_regs[] = {0, 1, 2, 0, 1, 2};
-#endif
-#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
-  static constexpr int b_regs[] = {6, 7};
-#else
-  static constexpr int b_regs[] = {6, 6};
-#endif
-  static constexpr int c_regs[] = {
-      8 , 16, 24,
-      9 , 17, 25,
-      10, 18, 26,
-      11, 19, 27,
-      12, 20, 28,
-      13, 21, 29,
-      14, 22, 30,
-      15, 23, 31,
-  };
-
-  static constexpr int a_shift = 128;
-  static constexpr int b_shift = 128;
-
-  static constexpr int nelems_in_cache_line = is_f32 ?
16 : 8; - static constexpr int a_prefetch_size = nelems_in_cache_line * 2; - static constexpr int b_prefetch_size = nelems_in_cache_line * 8; - - vec zmm[32]; - - // gemm arguments. - int64_t m; - const int64_t n, k, ldc; - const Scalar *alpha; - - const Scalar *a, *b; - Scalar *c; - - const bool is_alpha1; - const bool is_beta0; - - const int64_t a_stride, b_stride; - const int64_t a_off, b_off; - -public: - EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) - { - _mm_prefetch((char *) (a_prefetch_size + a_addr - a_shift), _MM_HINT_T0); - } - - EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) - { - _mm_prefetch((char *) (b_prefetch_size + b_addr - b_shift), _MM_HINT_T0); - } - - EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) - { - _mm_prefetch((char *) (x_addr - a_shift), _MM_HINT_T2); - } - - EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) - { -#if defined(__PRFCHW__) && __PRFCHW__ == 1 - _m_prefetchw((void *) c_addr); -#else - _mm_prefetch((char *) c_addr, _MM_HINT_T0); -#endif - } - - template - EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) - { - switch (nelems * sizeof(*a_addr) * 8) { - default: - case 512 * 3: a_reg = ploadu(a_addr); break; - case 512 * 2: a_reg = ploadu(a_addr); break; - case 512 * 1: a_reg = ploadu(a_addr); break; - case 256 * 1: a_reg = preinterpret(_mm512_broadcast_f64x4(ploadu(reinterpret_cast(a_addr)))); break; - case 128 * 1: a_reg = preinterpret(_mm512_broadcast_f32x4(ploadu(reinterpret_cast(a_addr)))); break; - case 64 * 1: a_reg = preinterpret(pload1(reinterpret_cast(a_addr))); break; - case 32 * 1: a_reg = pload1(a_addr); break; - } - } - - EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) - { - b_reg = pload1(b_addr); - } - - template - EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) - { - switch (nelems * sizeof(*mem) * 8) { - default: - case 512 * 3: pstoreu(mem, src); break; - case 512 * 2: pstoreu(mem, src); break; - case 512 * 1: pstoreu(mem, src); break; - case 256 * 1: pstoreu(mem, preinterpret(src)); break; - case 128 * 1: pstoreu(mem, preinterpret(src)); break; - case 64 * 1: pstorel(mem, preinterpret(src)); break; - case 32 * 1: pstores(mem, preinterpret(src)); break; - } - } - - template - EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src) - { - switch (nelems * sizeof(*mem) * 8) { - default: - case 512 * 3: dst = padd(src, ploadu(mem)); break; - case 512 * 2: dst = padd(src, ploadu(mem)); break; - case 512 * 1: dst = padd(src, ploadu(mem)); break; - case 256 * 1: dst = preinterpret(padd(preinterpret(src), ploadu(mem))); break; - case 128 * 1: dst = preinterpret(padd(preinterpret(src), ploadu(mem))); break; - case 64 * 1: dst = preinterpret(padd(preinterpret(src), ploadl(mem))); break; - case 32 * 1: dst = preinterpret(padds(preinterpret(src), ploads(mem))); break; - } - } - - EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) { - dst = pmadd(src1, src2, dst); - -#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0) - // Workaround register spills for gcc and clang - __asm__ ("#" : [dst] "+v" (dst) : [src1] "%v" (src1), [src2] "v" (src2)); -#endif - } - - template - EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale) - { - switch (nelems * sizeof(*mem) * 8) { - default: - case 512 * 3: dst = pmadd(scale, src, ploadu(mem)); break; - case 512 * 2: dst = pmadd(scale, src, ploadu(mem)); break; - case 512 * 1: dst = pmadd(scale, src, ploadu(mem)); break; - case 256 * 1: dst = 
preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); break; - case 128 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); break; - case 64 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadl(mem))); break; - case 32 * 1: dst = preinterpret(pmadds(preinterpret(scale), preinterpret(src), ploads(mem))); break; - } - } - - gemm_class(int64_t m_, int64_t n_, int64_t k_, int64_t ldc_, const Scalar *alpha_, - const Scalar *a_, const Scalar *b_, Scalar *c_, - bool is_alpha1_, bool is_beta0_, - int64_t a_stride_, int64_t b_stride_, - int64_t a_off_, int64_t b_off_) - : m(m_) - , n(n_) - , k(k_) - , ldc(ldc_) - , alpha(alpha_) - , a(a_) - , b(b_) - , c(c_) - , is_alpha1(is_alpha1_) - , is_beta0(is_beta0_) - , a_stride(a_stride_) - , b_stride(b_stride_) - , a_off(a_off_) - , b_off(b_off_) - { - - // Zero out all accumulation registers. - zmm[8 ] = pzero(zmm[8 ]); - zmm[9 ] = pzero(zmm[9 ]); - zmm[10] = pzero(zmm[10]); - zmm[11] = pzero(zmm[11]); - zmm[12] = pzero(zmm[12]); - zmm[13] = pzero(zmm[13]); - zmm[14] = pzero(zmm[14]); - zmm[15] = pzero(zmm[15]); - zmm[16] = pzero(zmm[16]); - zmm[17] = pzero(zmm[17]); - zmm[18] = pzero(zmm[18]); - zmm[19] = pzero(zmm[19]); - zmm[20] = pzero(zmm[20]); - zmm[21] = pzero(zmm[21]); - zmm[22] = pzero(zmm[22]); - zmm[23] = pzero(zmm[23]); - zmm[24] = pzero(zmm[24]); - zmm[25] = pzero(zmm[25]); - zmm[26] = pzero(zmm[26]); - zmm[27] = pzero(zmm[27]); - zmm[28] = pzero(zmm[28]); - zmm[29] = pzero(zmm[29]); - zmm[30] = pzero(zmm[30]); - zmm[31] = pzero(zmm[31]); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> - a_loads(const Scalar *ao) - { - EIGEN_UNUSED_VARIABLE(ao); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> - a_loads(const Scalar *ao) - { - if (j < endX) { - if (i < endY) { - auto &a_reg = zmm[a_regs[i + (j % 2) * 3]]; - const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift; - a_load(a_reg, a_addr); - - a_loads(ao); - } else { - a_loads(ao); - } - } - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> - prefetch_cs(const Scalar *co1, const Scalar *co2) - { - EIGEN_UNUSED_VARIABLE(co1); - EIGEN_UNUSED_VARIABLE(co2); - } - - /* C prefetch loop structure. - * for (int un = 0; un < 8; un++) { - * if (b_unroll >= un + 1) { - * if (un == 4) co2 = co1 + 4 * ldc; - * - * for (int i = 0; i < um_vecs; i++) { - * Scalar *co = (un + 1 <= 4) ? co1 : co2; - * auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; - * prefetch_c(co + co_off); - * } - * } - * } - */ - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> - prefetch_cs(Scalar *&co1, Scalar *&co2) - { - if (un < max_b_unroll) { - - if (b_unroll >= un + 1) { - if (un == 4 && i == 0) co2 = co1 + 4 * ldc; - - if (i < um_vecs) { - Scalar *co = (un + 1 <= 4) ? 
co1 : co2; - auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; - prefetch_c(co + co_off); - - prefetch_cs(co1, co2); - } else { - prefetch_cs(co1, co2); - } - - } else { - prefetch_cs(co1, co2); - } - } - } - - // load_c - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> - scale_load_c(const Scalar *cox, vec &alpha_reg) - { - EIGEN_UNUSED_VARIABLE(cox); - EIGEN_UNUSED_VARIABLE(alpha_reg); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> - scale_load_c(const Scalar *cox, vec &alpha_reg) - { - - if (i < um_vecs) { - auto &c_reg = zmm[c_regs[i + idx * 3]]; - auto c_mem = cox + i * nelems_in_cache_line; - - if (!is_beta0 && is_alpha1) - vaddm(c_reg, c_mem, c_reg); - else if (!is_beta0 && !is_alpha1) - vfmaddm(c_reg, c_mem, c_reg, alpha_reg); - else if (is_beta0 && !is_alpha1) - c_reg = pmul(alpha_reg, c_reg); - - scale_load_c(cox, alpha_reg); - } - } - - // store_c - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> - write_c(Scalar *cox) - { - EIGEN_UNUSED_VARIABLE(cox); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> - write_c(Scalar *cox) - { - if (i < um_vecs) { - auto &c_reg = zmm[c_regs[i + idx * 3]]; - auto c_mem = cox + i * nelems_in_cache_line; - - c_store(c_mem, c_reg); - c_reg = pzero(c_reg); - - write_c(cox); - } - } - - // update c matrix - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(pow > (max_b_unroll << 1)) || (count > (pow + 1) / 2 + 1)> - c_update(Scalar *&co1, Scalar *&co2) - { - EIGEN_UNUSED_VARIABLE(co1); - EIGEN_UNUSED_VARIABLE(co2); - } - - /* C update loop structure. - * co2 = co1 + ldc; - * - * auto &alpha_reg = zmm[0]; - * if (!is_alpha1) alpha_reg = pload1(alpha); - * - * int idx = 0; - * for (pow = 1; pow <= 8; pow <<= 1) { - * - * if (b_unroll >= pow) { - * for (count = 1; count < (pow + 1) / 2 + 1; count++) { - * if (pow >= 4) co2 += ldc; - * - * const Scalar *cox = (idx == 0) ? co1 : co2; - * - * const int um_vecs = div_up(a_unroll, nelems_in_cache_line); - * scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); - * write_c<0, um_vecs, idx, a_unroll>(cox); - * - * idx++; - * } - * } - * } - * - * if (b_unroll == 1) - * co1 += ldc; - * else - * co1 = co2 + ldc; - */ - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(pow <= (max_b_unroll << 1)) && (count <= (pow + 1) / 2 + 1)> - c_update(Scalar *&co1, Scalar *&co2) - { - const bool first_call = idx == 0; - auto &alpha_reg = zmm[0]; - - if (first_call) { - co2 = co1 + ldc; - if (!is_alpha1) alpha_reg = pload1(alpha); - } - - if (pow < (max_b_unroll << 1) && pow <= b_unroll) { - if (count < (pow + 1) / 2 + 1) { - if (pow >= 4) co2 += ldc; - - Scalar *cox = idx == 0 ? co1 : co2; - - const int um_vecs = div_up(a_unroll, nelems_in_cache_line); - scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); - write_c<0, um_vecs, idx, a_unroll>(cox); - - // Go to the next count and next idx. - c_update(co1, co2); - } else { - // Go to the next pow and reset count. 
- c_update(co1, co2); - } - } else { - if (b_unroll == 1) - co1 += ldc; - else - co1 = co2 + ldc; - } - } - - // compute - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> - compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg) - { - EIGEN_UNUSED_VARIABLE(ao); - EIGEN_UNUSED_VARIABLE(bo); - EIGEN_UNUSED_VARIABLE(fetchA_idx); - EIGEN_UNUSED_VARIABLE(fetchB_idx); - EIGEN_UNUSED_VARIABLE(b_reg); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> - compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg) - { - if (um < um_vecs) { - auto &c_reg = zmm[c_regs[um + idx * 3]]; - auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]]; - - vfmadd(c_reg, a_reg, b_reg); - - if (!fetch_x && um == 0 && (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) || (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) { - prefetch_a(ao + nelems_in_cache_line * fetchA_idx); - fetchA_idx++; - } - - if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) { - prefetch_b(bo + nelems_in_cache_line * fetchB_idx); - fetchB_idx++; - } - - compute(ao, bo, fetchA_idx, fetchB_idx, b_reg); - } - } - - // load_a - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> - load_a(const Scalar *ao) - { - EIGEN_UNUSED_VARIABLE(ao); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> - load_a(const Scalar *ao) - { - if (um < um_vecs) { - auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]]; -#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS - const Scalar *a_addr = ao + nelems * (1 + !ktail + uk) + nelems_in_cache_line * um - a_shift; -#else - const Scalar *a_addr = ao + nelems * (1 + uk) + nelems_in_cache_line * um - a_shift; -#endif - a_load(a_reg, a_addr); - - load_a(ao); - } - } - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> - innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) - { - EIGEN_UNUSED_VARIABLE(aa); - EIGEN_UNUSED_VARIABLE(ao); - EIGEN_UNUSED_VARIABLE(bo); - EIGEN_UNUSED_VARIABLE(co2); - EIGEN_UNUSED_VARIABLE(fetchA_idx); - EIGEN_UNUSED_VARIABLE(fetchB_idx); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> - innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) - { - const int idx = (pow / 2) + count; - - if (count < (pow + 1) / 2) { - auto &b_reg = zmm[b_regs[idx % 2]]; - - if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); - if (fetch_x && uk == 3 && idx == 4) aa += 8; - - if (b_unroll >= pow) { - - compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); - -#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS - const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift; -#else - const Scalar *b_addr = bo + b_unroll * uk + idx + 1 - b_shift; -#endif - b_load(b_reg, b_addr); - } - - // Go to the next count. - innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - - } else { - // Maybe prefetch C data after count-loop. 
- if (pow == 2 && c_fetch) { - if (uk % 3 == 0 && uk > 0) { - co2 += ldc; - } else { - prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); - } - } - } - } - - template - EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) - { - const int um_vecs = div_up(a_unroll, nelems_in_cache_line); - - if (max_b_unroll >= 1) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (max_b_unroll >= 2) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (max_b_unroll >= 4) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (max_b_unroll >= 8) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - - // Load A after pow-loop. - load_a<0, um_vecs, uk, a_unroll, ktail>(ao); - } - - /* Inner kernel loop structure. - * for (int uk = 0; uk < kfactor; uk++) { - * int idx = 0; - * - * for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) { - * for (int count = 0; count < (pow + 1) / 2; count++) { - * auto &b_reg = zmm[b_regs[idx % 2]]; - * - * if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); - * if (fetch_x && uk == 3 && idx == 4) aa += 8; - * - * if (b_unroll >= pow) { - * compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); - * - * const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ; - * b_load(b_reg, b_addr); - * } - * idx++; - * } - * - * Maybe prefetch C data. - * if (pow == 2 && c_fetch) { - * if (uk % 3 == 0 && uk > 0) { - * co2 += ldc; - * } else { - * prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); - * } - * } - * } - * - * Load A. - * load_a<0, um_vecs, uk, ktail, a_unroll>(ao); - * } - * - * Advance A/B pointers after uk-loop. - * ao += a_unroll * kfactor; - * bo += b_unroll * kfactor; - */ - - template - EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) - { - int fetchA_idx = 0; - int fetchB_idx = 0; - - const bool fetch_x = k_factor == max_k_factor; - const bool ktail = k_factor == 1; - - static_assert(k_factor <= 4 && k_factor > 0, - "innerkernel maximum k_factor supported is 4"); - - if (k_factor > 0) innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (k_factor > 1) innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (k_factor > 2) innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (k_factor > 3) innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - - // Advance A/B pointers after uk-loop. - ao += a_unroll * k_factor; - bo += b_unroll * k_factor; - } - - - template - EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) - { - const int um_vecs = div_up(a_unroll, nelems_in_cache_line); -#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS - a_loads<0, 2, 0, um_vecs, a_unroll>(ao); -#else - a_loads<0, 1, 0, um_vecs, a_unroll>(ao); -#endif - - b_load(zmm[b_regs[0]], bo - b_shift + 0); -#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS - b_load(zmm[b_regs[1]], bo - b_shift + 1); -#endif - -#ifndef SECOND_FETCH - prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2); -#endif // SECOND_FETCH - - // Unrolling k-loop by a factor of 4. 
- const int max_k_factor = 4; - int64_t loop_count = k / max_k_factor; - - if (loop_count > 0) { -#ifdef SECOND_FETCH - loop_count -= SECOND_FETCH; -#endif - while (loop_count > 0) { - innerkernel(aa, ao, bo, co2); - loop_count--; - } -#ifdef SECOND_FETCH - co2 = co1 + nelems_in_cache_line - 1; - - loop_count += b_unroll; - while (loop_count > 0) { - innerkernel(aa, ao, bo, co2); - loop_count--; - } - - loop_count += SECOND_FETCH - b_unroll; - while (loop_count > 0) { - innerkernel(aa, ao, bo, co2); - loop_count--; - } -#endif - } - - // k-loop remainder handling. - loop_count = k % max_k_factor; - while (loop_count > 0) { - innerkernel(aa, ao, bo, co2); - loop_count--; - } - - // Update C matrix. - c_update<1, max_b_unroll, 1, a_unroll, b_unroll, 0>(co1, co2); - } - - template - EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) - { - // Set A matrix pointer. - ao = a + a_off * a_unroll; - - // Set B matrix pointer if needed. - bo += b_unroll * b_off; - - kloop(aa, ao, bo, co1, co2); - - // Advance B matrix pointer if needed. - bo += b_unroll * (b_stride - k - b_off); - - // Advance prefetch A pointer. - aa += 16; - } - - template - EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) - { - // Set prefetch A pointers. - const Scalar *aa = a + a_unroll * a_stride; - - // Set C matrix pointers. - co1 = c; - if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc; - c += a_unroll; - - // Set B matrix pointer. - bo = b; - - // Main n-loop. - for (int64_t i = n / max_b_unroll; i > 0; i--) - nloop(aa, ao, bo, co1, co2); - - // n-remainders. - if (n & 4 && max_b_unroll > 4) nloop(aa, ao, bo, co1, co2); -#if 0 - if (n & 2 && max_b_unroll > 2) nloop(aa, ao, bo, co1, co2); - if (n & 1 && max_b_unroll > 1) nloop(aa, ao, bo, co1, co2); -#else - // Copy kernels don't support tails of n = 2 for single/double precision. - // Loop over ones. - int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0); - while (n_rem > 0) {nloop(aa, ao, bo, co1, co2); n_rem--;} -#endif - - // Advance A matrix pointer. - a = ao + a_unroll * (a_stride - k - a_off); - } - - // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll. - template - EIGEN_ALWAYS_INLINE void compute_kern() - { - a -= -a_shift; - b -= -b_shift; - - const Scalar *ao = nullptr; - const Scalar *bo = nullptr; - Scalar *co1 = nullptr; - Scalar *co2 = nullptr; - - // Main m-loop. - for (; m >= max_a_unroll; m -= max_a_unroll) - mloop(ao, bo, co1, co2); - - // m-remainders. - if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 8 && max_a_unroll > 8) mloop< 8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 4 && max_a_unroll > 4) mloop< 4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 2 && max_a_unroll > 2 && is_f64) mloop< 2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 1 && max_a_unroll > 1 && is_f64) mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - - // Copy kernels don't support tails of m = 2 for single precision. - // Loop over ones. 
- if (is_f32) { - int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0); - while (m_rem > 0) {mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); m_rem--;} - } - } -}; - -// Compute kernel with max unroll support of: -// Single precision: -// max_a_unroll: 48, 32, 16, 8, 4, 2, 1 -// max_b_unroll: 8, 4, 2, 1 -// Double precision: -// max_a_unroll: 24, 16, 8, 4, 2, 1 -// max_b_unroll: 8, 4, 2, 1 -template -EIGEN_DONT_INLINE void gemm_kern_avx512(int64_t *p_m, int64_t *p_n, int64_t *p_k, - Scalar *alpha, const Scalar *a, const Scalar *b, Scalar *c, - int64_t ldc, int64_t a_stride = -1, int64_t b_stride = -1, - int64_t a_off = 0, int64_t b_off = 0) -{ - if (a_stride == -1) a_stride = *p_k; - if (b_stride == -1) b_stride = *p_k; - - gemm_class g(*p_m, *p_n, *p_k, ldc, alpha, a, b, c, - is_alpha1, is_beta0, a_stride, b_stride, a_off, b_off); - g.template compute_kern(); -} - -template -bool gemm_kernel(int64_t m, int64_t n, int64_t k, c_t alpha, - const a_t *a, const b_t *b, c_t *c, int64_t ldc, - int64_t a_stride = -1, int64_t b_stride = -1, - int64_t a_off = 0, int64_t b_off = 0) -{ - EIGEN_UNUSED_VARIABLE(m); - EIGEN_UNUSED_VARIABLE(n); - EIGEN_UNUSED_VARIABLE(k); - EIGEN_UNUSED_VARIABLE(alpha); - EIGEN_UNUSED_VARIABLE(a); - EIGEN_UNUSED_VARIABLE(b); - EIGEN_UNUSED_VARIABLE(c); - EIGEN_UNUSED_VARIABLE(ldc); - EIGEN_UNUSED_VARIABLE(a_stride); - EIGEN_UNUSED_VARIABLE(b_stride); - EIGEN_UNUSED_VARIABLE(a_off); - EIGEN_UNUSED_VARIABLE(b_off); - return false; -} - -template <> -bool gemm_kernel(int64_t m, int64_t n, int64_t k, float alpha, - const float *a, const float *b, float *c, int64_t ldc, - int64_t a_stride, int64_t b_stride, - int64_t a_off, int64_t b_off) -{ - if (alpha == 1.f) - gemm_kern_avx512(&m, &n, &k, &alpha, a, b, c, - ldc, a_stride, b_stride, a_off, b_off); - else - gemm_kern_avx512(&m, &n, &k, &alpha, a, b, c, - ldc, a_stride, b_stride, a_off, b_off); - - return true; -} - -template <> -bool gemm_kernel(int64_t m, int64_t n, int64_t k, double alpha, - const double *a, const double *b, double *c, int64_t ldc, - int64_t a_stride, int64_t b_stride, - int64_t a_off, int64_t b_off) -{ - if (alpha == 1.) - gemm_kern_avx512(&m, &n, &k, &alpha, a, b, c, - ldc, a_stride, b_stride, a_off, b_off); - else - gemm_kern_avx512(&m, &n, &k, &alpha, a, b, c, - ldc, a_stride, b_stride, a_off, b_off); - - return true; -} - -template -struct gemm_pack_rhs; - -template -struct gemm_pack_rhs -{ - typedef typename packet_traits::type Packet; - typedef typename DataMapper::LinearMapper LinearMapper; - enum { PacketSize = packet_traits::size }; - EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); -}; - -template -EIGEN_DONT_INLINE void gemm_pack_rhs - ::operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) -{ - constexpr int nr = 8; - EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR"); - EIGEN_UNUSED_VARIABLE(stride); - EIGEN_UNUSED_VARIABLE(offset); - eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); - conj_if::IsComplex && Conjugate> cj; - Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0; - Index packet_cols4 = nr>=4 ? 
(cols/4) * 4 : 0; - Index count = 0; - const Index peeled_k = (depth/PacketSize)*PacketSize; - if(nr>=8) - { - for(Index j2=0; j2 kernel; - - kernel.packet[0] = dm0.template loadPacket(k); - kernel.packet[1] = dm1.template loadPacket(k); - kernel.packet[2] = dm2.template loadPacket(k); - kernel.packet[3] = dm3.template loadPacket(k); - kernel.packet[4] = dm4.template loadPacket(k); - kernel.packet[5] = dm5.template loadPacket(k); - kernel.packet[6] = dm6.template loadPacket(k); - kernel.packet[7] = dm7.template loadPacket(k); - - ptranspose(kernel); - - pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0])); - pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize])); - pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize])); - pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize])); - pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel.packet[4%PacketSize])); - pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel.packet[5%PacketSize])); - pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel.packet[6%PacketSize])); - pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel.packet[7%PacketSize])); - count+=8*PacketSize; - } - } - for(; k=4) - { - for(Index j2=packet_cols8; j2 kernel; - kernel.packet[0 ] = dm0.template loadPacket(k); - kernel.packet[1%PacketSize] = dm1.template loadPacket(k); - kernel.packet[2%PacketSize] = dm2.template loadPacket(k); - kernel.packet[3%PacketSize] = dm3.template loadPacket(k); - ptranspose(kernel); - pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0])); - pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize])); - pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize])); - pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize])); - count+=4*PacketSize; - } - } - for(; k& kernel) { EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \ INPUT[2 * INDEX + STRIDE]); -template EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]); __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]); @@ -1451,49 +1450,28 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6))); kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7))); kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7))); + + T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); + T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); + T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); + T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); + T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); + T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); + T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); + T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); + T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); + T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); + T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); + T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); + T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); + T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], 
T3); + T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); + T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); - // Transpose for gemm is slightly different than trsm. - if (!for_trsm) { - T0 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0x44); - T1 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0xee); - T2 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0x44); - T3 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0xee); - T4 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0x44); - T5 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0xee); - T6 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0x44); - T7 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0xee); - - kernel.packet[0] = _mm512_shuffle_f32x4(T0, T2, 0x88); - kernel.packet[2] = _mm512_shuffle_f32x4(T0, T2, 0xdd); - kernel.packet[1] = _mm512_shuffle_f32x4(T4, T6, 0x88); - kernel.packet[3] = _mm512_shuffle_f32x4(T4, T6, 0xdd); - kernel.packet[4] = _mm512_shuffle_f32x4(T1, T3, 0x88); - kernel.packet[6] = _mm512_shuffle_f32x4(T1, T3, 0xdd); - kernel.packet[5] = _mm512_shuffle_f32x4(T5, T7, 0x88); - kernel.packet[7] = _mm512_shuffle_f32x4(T5, T7, 0xdd); - } else { - T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); - T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); - T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); - T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); - T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); - T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); - T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); - T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); - T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); - T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); - T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); - T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); - T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); - T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); - T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); - T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); - - kernel.packet[0] = T0; kernel.packet[1] = T1; - kernel.packet[2] = T2; kernel.packet[3] = T3; - kernel.packet[4] = T4; kernel.packet[5] = T5; - kernel.packet[6] = T6; kernel.packet[7] = T7; - } + kernel.packet[0] = T0; kernel.packet[1] = T1; + kernel.packet[2] = T2; kernel.packet[3] = T3; + kernel.packet[4] = T4; kernel.packet[5] = T5; + kernel.packet[6] = T6; kernel.packet[7] = T7; } EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { @@ -1571,9 +1549,7 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { PACK_OUTPUT_D(kernel.packet, tmp.packet, 3, 1); } -template EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { - // Transpose for trsm is the same as for gemm. 
__m512d T0 = _mm512_unpacklo_pd(kernel.packet[0],kernel.packet[1]); __m512d T1 = _mm512_unpackhi_pd(kernel.packet[0],kernel.packet[1]); __m512d T2 = _mm512_unpacklo_pd(kernel.packet[2],kernel.packet[3]); diff --git a/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc index 9fd7de92d..22cb1c93d 100644 --- a/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +++ b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc @@ -198,7 +198,7 @@ public: r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5]; r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6]; r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7]; - ptranspose(r); + ptranspose(r); zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0]; zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1]; zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2]; diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h index d28cca214..8baced1d6 100644 --- a/Eigen/src/Core/arch/AVX512/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -44,34 +44,6 @@ template<> EIGEN_STRONG_INLINE Packet8f preinterpret(const return _mm512_castps512_ps256(a); } -template<> EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet16f& a) { - return _mm512_castps512_ps128(a); -} - -template<> EIGEN_STRONG_INLINE Packet4d preinterpret(const Packet8d& a) { - return _mm512_castpd512_pd256(a); -} - -template<> EIGEN_STRONG_INLINE Packet2d preinterpret(const Packet8d& a) { - return _mm512_castpd512_pd128(a); -} - -template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet8f& a) { - return _mm512_castps256_ps512(a); -} - -template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet4f& a) { - return _mm512_castps128_ps512(a); -} - -template<> EIGEN_STRONG_INLINE Packet8d preinterpret(const Packet4d& a) { - return _mm512_castpd256_pd512(a); -} - -template<> EIGEN_STRONG_INLINE Packet8d preinterpret(const Packet2d& a) { - return _mm512_castpd128_pd512(a); -} - template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet16f& a) { return a; } diff --git a/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h index d64b1a021..6cd6edd56 100644 --- a/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h @@ -8,9 +8,9 @@ namespace internal { // Clang seems to excessively spill registers in the GEBP kernel on 32-bit arm. // Here we specialize gebp_traits to eliminate these register spills. // See #2138. 
-template -struct gebp_traits - : gebp_traits +template<> +struct gebp_traits + : gebp_traits { EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const { @@ -43,9 +43,9 @@ struct gebp_traits -struct gebp_traits - : gebp_traits +template<> +struct gebp_traits + : gebp_traits { typedef float RhsPacket; typedef float32x4_t RhsPacketx4; @@ -108,9 +108,9 @@ struct gebp_traits -struct gebp_traits - : gebp_traits +template<> +struct gebp_traits + : gebp_traits { typedef double RhsPacket; diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index e896040b5..35490a602 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -285,10 +285,6 @@ template<> EIGEN_STRONG_INLINE Packet4i padd(const Packet4i& a, const template<> EIGEN_STRONG_INLINE Packet16b padd(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); } -template EIGEN_STRONG_INLINE Packet padds(const Packet& a, const Packet& b); -template<> EIGEN_STRONG_INLINE Packet4f padds(const Packet4f& a, const Packet4f& b) { return _mm_add_ss(a,b); } -template<> EIGEN_STRONG_INLINE Packet2d padds(const Packet2d& a, const Packet2d& b) { return _mm_add_sd(a,b); } - template<> EIGEN_STRONG_INLINE Packet4f psub(const Packet4f& a, const Packet4f& b) { return _mm_sub_ps(a,b); } template<> EIGEN_STRONG_INLINE Packet2d psub(const Packet2d& a, const Packet2d& b) { return _mm_sub_pd(a,b); } template<> EIGEN_STRONG_INLINE Packet4i psub(const Packet4i& a, const Packet4i& b) { return _mm_sub_epi32(a,b); } @@ -374,10 +370,6 @@ template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f template<> EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmadd_pd(a,b,c); } template<> EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmsub_ps(a,b,c); } template<> EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmsub_pd(a,b,c); } - -template EIGEN_STRONG_INLINE Packet pmadds(const Packet& a, const Packet& b, const Packet& c); -template<> EIGEN_STRONG_INLINE Packet4f pmadds(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ss(a,b,c); } -template<> EIGEN_STRONG_INLINE Packet2d pmadds(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_sd(a,b,c); } #endif #ifdef EIGEN_VECTORIZE_SSE4_1 @@ -754,15 +746,6 @@ template<> EIGEN_STRONG_INLINE Packet16b ploadu(const bool* from) return _mm_loadu_si128(reinterpret_cast(from)); } -// Load lower part of packet zero extending. 
-template EIGEN_STRONG_INLINE Packet ploadl(const typename unpacket_traits::type* from); -template<> EIGEN_STRONG_INLINE Packet4f ploadl(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_castpd_ps(_mm_load_sd(reinterpret_cast(from))); } -template<> EIGEN_STRONG_INLINE Packet2d ploadl(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); } - -// Load scalar -template EIGEN_STRONG_INLINE Packet ploads(const typename unpacket_traits::type* from); -template<> EIGEN_STRONG_INLINE Packet4f ploads(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_ss(from); } -template<> EIGEN_STRONG_INLINE Packet2d ploads(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); } template<> EIGEN_STRONG_INLINE Packet4f ploaddup(const float* from) { @@ -804,14 +787,6 @@ template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet4f& template<> EIGEN_STRONG_INLINE void pstoreu(int* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); } template<> EIGEN_STRONG_INLINE void pstoreu(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); } -template EIGEN_STRONG_INLINE void pstorel(Scalar* to, const Packet& from); -template<> EIGEN_STRONG_INLINE void pstorel(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pi(reinterpret_cast<__m64*>(to), from); } -template<> EIGEN_STRONG_INLINE void pstorel(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pd(to, from); } - -template EIGEN_STRONG_INLINE void pstores(Scalar* to, const Packet& from); -template<> EIGEN_STRONG_INLINE void pstores(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_ss(to, from); } -template<> EIGEN_STRONG_INLINE void pstores(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_sd(to, from); } - template<> EIGEN_DEVICE_FUNC inline Packet4f pgather(const float* from, Index stride) { return _mm_set_ps(from[3*stride], from[2*stride], from[1*stride], from[0*stride]); diff --git a/Eigen/src/Core/arch/SSE/TypeCasting.h b/Eigen/src/Core/arch/SSE/TypeCasting.h index a6346ea0e..c21d1acd9 100644 --- a/Eigen/src/Core/arch/SSE/TypeCasting.h +++ b/Eigen/src/Core/arch/SSE/TypeCasting.h @@ -71,14 +71,6 @@ template<> EIGEN_STRONG_INLINE Packet2d pcast(const Packet4f return _mm_cvtps_pd(a); } -template<> EIGEN_STRONG_INLINE Packet2d preinterpret(const Packet4f& a) { - return _mm_castps_pd(a); -} - -template<> EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet2d& a) { - return _mm_castpd_ps(a); -} - template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet4f& a) { return _mm_castps_si128(a); } diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index 2502cd9a5..b1a127754 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -2,7 +2,6 @@ // for linear algebra. // // Copyright (C) 2008-2009 Gael Guennebaud -// Modifications Copyright (C) 2022 Intel Corporation // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. 
If a copy of the MPL was not distributed @@ -24,7 +23,7 @@ enum GEBPPacketSizeType { GEBPPacketQuarter }; -template +template class gebp_traits; @@ -126,7 +125,7 @@ inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1, std::ptrdiff template void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index num_threads = 1) { - typedef gebp_traits Traits; + typedef gebp_traits Traits; // Explanations: // Let's recall that the product algorithms form mc x kc vertical panels A' on the lhs and @@ -417,7 +416,7 @@ struct packet_conditional { typedef T2 type; }; * cplx*real : unpack rhs to constant packets, ... * real*cplx : load lhs as (a0,a0,a1,a1), and mul as usual */ -template +template class gebp_traits { public: @@ -430,7 +429,6 @@ public: PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_); enum { - UnitResIncr = UnitResIncr_, ConjLhs = ConjLhs_, ConjRhs = ConjRhs_, Vectorizable = unpacket_traits::vectorizable && unpacket_traits::vectorizable, @@ -439,17 +437,9 @@ public: ResPacketSize = Vectorizable ? unpacket_traits::size : 1, NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS, - IsReal = std::is_same::value - && (std::is_same::value - || std::is_same::value), // register block size along the N direction must be 1 or 4 -#if defined(EIGEN_VECTORIZE_AVX512) - // AVX512 support nr = 8 for unit inner strides for result matrix. - nr = IsReal && Vectorizable && UnitResIncr ? 8 : 4, -#else nr = 4, -#endif // register block size along the M direction (currently, this one cannot be modified) default_mr = (plain_enum_min(16, NumberOfRegisters)/2/nr)*LhsPacketSize, @@ -555,9 +545,8 @@ public: }; - -template -class gebp_traits, RealScalar, UnitResIncr_, ConjLhs_, false, Arch, PacketSize_> +template +class gebp_traits, RealScalar, ConjLhs_, false, Arch, PacketSize_> { public: typedef std::complex LhsScalar; @@ -767,8 +756,8 @@ template struct unpacket_traits > { // return res; // } -template -class gebp_traits, std::complex, UnitResIncr_, ConjLhs_, ConjRhs_, Arch, PacketSize_ > +template +class gebp_traits, std::complex, ConjLhs_, ConjRhs_, Arch, PacketSize_ > { public: typedef std::complex Scalar; @@ -933,8 +922,8 @@ protected: conj_helper cj; }; -template -class gebp_traits, UnitResIncr, false, ConjRhs_, Arch, PacketSize_ > +template +class gebp_traits, false, ConjRhs_, Arch, PacketSize_ > { public: typedef std::complex Scalar; @@ -1069,9 +1058,9 @@ protected: template struct gebp_kernel { - typedef gebp_traits Traits; - typedef gebp_traits HalfTraits; - typedef gebp_traits QuarterTraits; + typedef gebp_traits Traits; + typedef gebp_traits HalfTraits; + typedef gebp_traits QuarterTraits; typedef typename Traits::ResScalar ResScalar; typedef typename Traits::LhsPacket LhsPacket; @@ -1082,7 +1071,7 @@ struct gebp_kernel typedef typename RhsPanelHelper::type RhsPanel15; - typedef gebp_traits SwappedTraits; + typedef gebp_traits SwappedTraits; typedef typename SwappedTraits::ResScalar SResScalar; typedef typename SwappedTraits::LhsPacket SLhsPacket; @@ -1120,11 +1109,11 @@ struct gebp_kernel }; template::LhsProgress> +int SwappedLhsProgress = gebp_traits::LhsProgress> struct last_row_process_16_packets { - typedef gebp_traits Traits; - typedef gebp_traits SwappedTraits; + typedef gebp_traits Traits; + typedef gebp_traits SwappedTraits; typedef typename Traits::ResScalar ResScalar; typedef typename SwappedTraits::LhsPacket SLhsPacket; @@ -1152,8 +1141,8 @@ struct last_row_process_16_packets template struct last_row_process_16_packets { - typedef gebp_traits Traits; - typedef 
gebp_traits SwappedTraits; + typedef gebp_traits Traits; + typedef gebp_traits SwappedTraits; typedef typename Traits::ResScalar ResScalar; typedef typename SwappedTraits::LhsPacket SLhsPacket; @@ -1419,15 +1408,6 @@ void gebp_kernel=4 ? (cols/4) * 4 : 0; Index count = 0; const Index peeled_k = (depth/PacketSize)*PacketSize; - if(nr>=8) - { - for(Index j2=0; j2 kernel; - - kernel.packet[0] = dm0.template loadPacket(k); - kernel.packet[1] = dm1.template loadPacket(k); - kernel.packet[2] = dm2.template loadPacket(k); - kernel.packet[3] = dm3.template loadPacket(k); - kernel.packet[4] = dm4.template loadPacket(k); - kernel.packet[5] = dm5.template loadPacket(k); - kernel.packet[6] = dm6.template loadPacket(k); - kernel.packet[7] = dm7.template loadPacket(k); - - ptranspose(kernel); - - pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0])); - pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize])); - pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize])); - pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize])); - pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel.packet[4%PacketSize])); - pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel.packet[5%PacketSize])); - pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel.packet[6%PacketSize])); - pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel.packet[7%PacketSize])); - count+=8*PacketSize; - } - } -#endif - for(; k=8) +// { +// for(Index j2=0; j2 kernel; +// for (int p = 0; p < PacketSize; ++p) { +// kernel.packet[p] = ploadu(&rhs[(j2+p)*rhsStride+k]); +// } +// ptranspose(kernel); +// for (int p = 0; p < PacketSize; ++p) { +// pstoreu(blockB+count, cj.pconj(kernel.packet[p])); +// count+=PacketSize; +// } +// } +// } +// for(; k=4) { @@ -2558,50 +2522,39 @@ struct gemm_pack_rhs=4 ? 
(cols/4) * 4 : 0; Index count = 0; - if(nr>=8) - { - for(Index j2=0; j2(&rhs.data()[k*rhs.stride() + j2]); - Packet A = rhs.template loadPacket(k, j2); - pstoreu(blockB+count, cj.pconj(A)); - } else if (HasHalf && HalfPacketSize==8) { - HalfPacket A = rhs.template loadPacket(k, j2); - pstoreu(blockB+count, cj.pconj(A)); - } else if (HasQuarter && QuarterPacketSize==8) { - QuarterPacket A = rhs.template loadPacket(k, j2); - pstoreu(blockB+count, cj.pconj(A)); - } else if (PacketSize==4) { - // Packet A = ploadu(&rhs.data()[k*rhs.stride() + j2]); - // Packet B = ploadu(&rhs.data()[k*rhs.stride() + j2 + PacketSize]); - Packet A = rhs.template loadPacket(k, j2); - Packet B = rhs.template loadPacket(k, j2 + PacketSize); - pstoreu(blockB+count, cj.pconj(A)); - pstoreu(blockB+count+PacketSize, cj.pconj(B)); - } else { - // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2]; - const LinearMapper dm0 = rhs.getLinearMapper(k, j2); - blockB[count+0] = cj(dm0(0)); - blockB[count+1] = cj(dm0(1)); - blockB[count+2] = cj(dm0(2)); - blockB[count+3] = cj(dm0(3)); - blockB[count+4] = cj(dm0(4)); - blockB[count+5] = cj(dm0(5)); - blockB[count+6] = cj(dm0(6)); - blockB[count+7] = cj(dm0(7)); - } - count += 8; - } - // skip what we have after - if(PanelMode) count += 8 * (stride-offset-depth); - } - } - + // if(nr>=8) + // { + // for(Index j2=0; j2(&rhs[k*rhsStride + j2]); + // pstoreu(blockB+count, cj.pconj(A)); + // } else if (PacketSize==4) { + // Packet A = ploadu(&rhs[k*rhsStride + j2]); + // Packet B = ploadu(&rhs[k*rhsStride + j2 + PacketSize]); + // pstoreu(blockB+count, cj.pconj(A)); + // pstoreu(blockB+count+PacketSize, cj.pconj(B)); + // } else { + // const Scalar* b0 = &rhs[k*rhsStride + j2]; + // blockB[count+0] = cj(b0[0]); + // blockB[count+1] = cj(b0[1]); + // blockB[count+2] = cj(b0[2]); + // blockB[count+3] = cj(b0[3]); + // blockB[count+4] = cj(b0[4]); + // blockB[count+5] = cj(b0[5]); + // blockB[count+6] = cj(b0[6]); + // blockB[count+7] = cj(b0[7]); + // } + // count += 8; + // } + // // skip what we have after + // if(PanelMode) count += 8 * (stride-offset-depth); + // } + // } if(nr>=4) { for(Index j2=packet_cols8; j2 struct general_matrix_matrix_product { - typedef gebp_traits Traits; + typedef gebp_traits Traits; typedef typename ScalarBinaryOpTraits::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run( @@ -57,7 +57,7 @@ template< struct general_matrix_matrix_product { -typedef gebp_traits Traits; +typedef gebp_traits Traits; typedef typename ScalarBinaryOpTraits::ReturnType ResScalar; static void run(Index rows, Index cols, Index depth, @@ -287,6 +287,7 @@ class gemm_blocking_space LhsScalar; typedef std::conditional_t RhsScalar; + typedef gebp_traits Traits; enum { SizeA = ActualRows * MaxDepth, SizeB = ActualCols * MaxDepth @@ -335,6 +336,7 @@ class gemm_blocking_space LhsScalar; typedef std::conditional_t RhsScalar; + typedef gebp_traits Traits; Index m_sizeA; Index m_sizeB; diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h index b0024e75b..716f2ca78 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h @@ -67,7 +67,7 @@ struct general_matrix_matrix_triangular_product& blocking) { - typedef gebp_traits Traits; + typedef gebp_traits Traits; typedef const_blas_data_mapper LhsMapper; typedef const_blas_data_mapper RhsMapper; @@ -140,7 +140,7 @@ struct general_matrix_matrix_triangular_product struct tribb_kernel { - typedef 
gebp_traits Traits; + typedef gebp_traits Traits; typedef typename Traits::ResScalar ResScalar; enum { diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h index 66217e5d9..490fe67c5 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h @@ -55,7 +55,7 @@ template< \ int RhsStorageOrder, bool ConjugateRhs> \ struct general_matrix_matrix_product \ { \ -typedef gebp_traits Traits; \ +typedef gebp_traits Traits; \ \ static void run(Index rows, Index cols, Index depth, \ const EIGTYPE* _lhs, Index lhsStride, \ diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h index 7c547870a..c7bb44596 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h @@ -351,7 +351,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix Traits; + typedef gebp_traits Traits; typedef const_blas_data_mapper LhsMapper; typedef const_blas_data_mapper LhsTransposeMapper; @@ -446,7 +446,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix Traits; + typedef gebp_traits Traits; typedef const_blas_data_mapper LhsMapper; typedef blas_data_mapper ResMapper; diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h index 6266dad86..770107a0b 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h @@ -89,7 +89,7 @@ struct product_triangular_matrix_matrix { - typedef gebp_traits Traits; + typedef gebp_traits Traits; enum { SmallPanelWidth = 2 * plain_enum_max(Traits::mr, Traits::nr), IsLower = (Mode&Lower) == Lower, @@ -247,7 +247,7 @@ struct product_triangular_matrix_matrix { - typedef gebp_traits Traits; + typedef gebp_traits Traits; enum { SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr), IsLower = (Mode&Lower) == Lower, diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index 2e6d6a8ca..def6a28f2 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -189,7 +189,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix Traits; + typedef gebp_traits Traits; enum { SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr), @@ -336,7 +336,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix Traits; + typedef gebp_traits Traits; enum { RhsStorageOrder = TriStorageOrder, SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr), diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index f59a55faa..f45665e3b 100755 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -173,7 +173,6 @@ class blas_data_mapper public: typedef BlasLinearMapper LinearMapper; typedef BlasVectorMapper VectorMapper; - static constexpr int incr = 1; EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1) : m_data(data), m_stride(stride) @@ -286,7 +285,6 @@ class blas_data_mapper { public: typedef BlasLinearMapper LinearMapper; - static constexpr int incr = Incr; EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {} @@ -404,9 +402,6 @@ public: storePacketBlock_helper spb; spb.store(this, i,j,block); } - - EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; } - EIGEN_DEVICE_FUNC 
Scalar* data() const { return m_data; } protected: Scalar* EIGEN_RESTRICT m_data; const Index m_stride; diff --git a/unsupported/Eigen/MPRealSupport b/unsupported/Eigen/MPRealSupport index 8167bb1c2..c4ea4ec5f 100644 --- a/unsupported/Eigen/MPRealSupport +++ b/unsupported/Eigen/MPRealSupport @@ -143,8 +143,8 @@ int main() // Specialize GEBP kernel and traits for mpreal (no need for peeling, nor complicated stuff) // This also permits to directly call mpfr's routines and avoid many temporaries produced by mpreal - template - class gebp_traits + template<> + class gebp_traits { public: typedef mpfr::mpreal ResScalar;
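
The removed GemmKernel.h documents its accumulation scheme in its own "Inner kernel loop structure" and "C update loop structure" comments: an mr x nr tile of C is kept in accumulation registers, updated along k, then scaled by alpha and written back. The stand-alone sketch below shows that scheme in plain scalar C++; the tile sizes MR/NR, the helper name micro_kernel and the packed-panel layout are illustrative assumptions rather than Eigen identifiers, and the real kernel additionally vectorizes the M direction with zmm registers, unrolls the k-loop by a factor of 4 and prefetches A, B and C.

#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative tile sizes; the reverted kernel uses register tiles of up to
// 48x8 (float) and 24x8 (double).
constexpr int MR = 4;
constexpr int NR = 4;

// Compute one MR x NR tile of C += alpha * A_panel * B_panel, with C stored
// column-major (leading dimension ldc), A packed as MR-wide column slices and
// B packed as NR-wide row slices. acc[][] plays the role of the zmm
// accumulation registers zeroed in the kernel's constructor.
void micro_kernel(std::size_t k, const double* a_panel, const double* b_panel,
                  double* c, std::size_t ldc, double alpha) {
  double acc[MR][NR] = {};
  for (std::size_t p = 0; p < k; ++p) {      // k-loop
    for (int j = 0; j < NR; ++j)
      for (int i = 0; i < MR; ++i)
        acc[i][j] += a_panel[p * MR + i] * b_panel[p * NR + j];
  }
  for (int j = 0; j < NR; ++j)               // C update with alpha scaling
    for (int i = 0; i < MR; ++i)
      c[j * ldc + i] += alpha * acc[i][j];
}

int main() {
  const std::size_t k = 3, ldc = MR;
  std::vector<double> a(k * MR, 1.0), b(k * NR, 2.0), c(MR * NR, 0.0);
  micro_kernel(k, a.data(), b.data(), c.data(), ldc, 1.0);
  std::printf("c(0,0) = %g\n", c[0]);        // 3 * (1 * 2) = 6
  return 0;
}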
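
The revert also removes the nr = 8 right-hand-side packing path (the gemm_pack_rhs specializations guarded by nr>=8 above). Stripped of packets, conjugation and panel stride/offset handling, that path lays out eight columns of a column-major B so that the eight values at each depth index become contiguous; the removed code obtains the same reordering by loading whole packets from eight columns and transposing them with ptranspose. The function name pack_rhs_8 and its parameters below are illustrative assumptions only.

#include <cstddef>
#include <vector>

// Pack a depth x 8 block of column-major B (columns j2..j2+7, leading
// dimension rhsStride) into blockB so that b(k, j2..j2+7) is contiguous for
// each k -- the layout an nr = 8 GEBP kernel consumes.
void pack_rhs_8(double* blockB, const double* rhs, std::size_t rhsStride,
                std::size_t depth, std::size_t j2) {
  std::size_t count = 0;
  for (std::size_t k = 0; k < depth; ++k)
    for (int j = 0; j < 8; ++j)
      blockB[count++] = rhs[(j2 + j) * rhsStride + k];  // column j2+j, row k
}

int main() {
  const std::size_t depth = 4, cols = 8, rhsStride = depth;
  std::vector<double> rhs(rhsStride * cols);
  for (std::size_t i = 0; i < rhs.size(); ++i) rhs[i] = double(i);
  std::vector<double> blockB(depth * 8);
  pack_rhs_8(blockB.data(), rhs.data(), rhsStride, depth, 0);
  // blockB now starts with b(0,0), b(0,1), ..., b(0,7), then b(1,0), ...
  return 0;
}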