From 37673ca1bc71ed1cd6e10397241376c34bd7251f Mon Sep 17 00:00:00 2001 From: b-shi Date: Fri, 17 Jun 2022 18:05:26 +0000 Subject: [PATCH] AVX512 TRSM kernels use alloca if EIGEN_NO_MALLOC requested --- Eigen/src/Core/arch/AVX512/GemmKernel.h | 1881 +++++++++-------- Eigen/src/Core/arch/AVX512/TrsmKernel.h | 1174 +++++----- Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc | 1056 +++++---- .../Core/products/TriangularSolverMatrix.h | 6 +- 4 files changed, 2087 insertions(+), 2030 deletions(-) diff --git a/Eigen/src/Core/arch/AVX512/GemmKernel.h b/Eigen/src/Core/arch/AVX512/GemmKernel.h index d7198986e..477c50fa5 100644 --- a/Eigen/src/Core/arch/AVX512/GemmKernel.h +++ b/Eigen/src/Core/arch/AVX512/GemmKernel.h @@ -20,832 +20,895 @@ #include "../../InternalHeaderCheck.h" -#define SECOND_FETCH (32) +#define SECOND_FETCH (32) #if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS) // Use less registers to load A elements to workaround compiler spills. Loose a // bit of performance (less than ~2%). #define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS #endif - namespace Eigen { namespace internal { template -class gemm_class -{ - using vec = typename packet_traits::type; - using vec_ymm = typename unpacket_traits::half; - using vec_xmm = typename unpacket_traits::half; - using umask_t = typename unpacket_traits::mask_t; +class gemm_class { + using vec = typename packet_traits::type; + using vec_ymm = typename unpacket_traits::half; + using vec_xmm = typename unpacket_traits::half; + using umask_t = typename unpacket_traits::mask_t; - static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float); - static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double); + static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float); + static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double); #ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS - static constexpr bool use_less_a_regs = !is_unit_inc; + static constexpr bool use_less_a_regs = !is_unit_inc; #else - static constexpr bool use_less_a_regs = true; + static constexpr bool use_less_a_regs = true; #endif #ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS - static constexpr bool use_less_b_regs = !is_unit_inc; + static constexpr bool use_less_b_regs = !is_unit_inc; #else - static constexpr bool use_less_b_regs = true; + static constexpr bool use_less_b_regs = true; #endif - static constexpr int a_regs[] = {0, 1, 2, - use_less_a_regs ? 0 : 3, - use_less_a_regs ? 1 : 4, - use_less_a_regs ? 2 : 5 - }; - static constexpr int b_regs[] = {6, - use_less_b_regs ? 6 : 7 - }; - static constexpr int c_regs[] = { - 8 , 16, 24, - 9 , 17, 25, - 10, 18, 26, - 11, 19, 27, - 12, 20, 28, - 13, 21, 29, - 14, 22, 30, - 15, 23, 31, - }; + static constexpr int a_regs[] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5}; + static constexpr int b_regs[] = {6, use_less_b_regs ? 6 : 7}; + static constexpr int c_regs[] = { + 8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31, + }; - static constexpr int alpha_load_reg = 0; - static constexpr int c_load_regs[] = {1, 2, 6}; + static constexpr int alpha_load_reg = 0; + static constexpr int c_load_regs[] = {1, 2, 6}; - static constexpr int a_shift = 128; - static constexpr int b_shift = 128; + static constexpr int a_shift = 128; + static constexpr int b_shift = 128; - static constexpr int nelems_in_cache_line = is_f32 ? 
16 : 8; - static constexpr int a_prefetch_size = nelems_in_cache_line * 2; - static constexpr int b_prefetch_size = nelems_in_cache_line * 8; + static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8; + static constexpr int a_prefetch_size = nelems_in_cache_line * 2; + static constexpr int b_prefetch_size = nelems_in_cache_line * 8; - vec zmm[32]; - umask_t mask; + vec zmm[32]; + umask_t mask; - // gemm arguments. - Index m; - const Index n, k, ldc; - const Index inc; - const Scalar *alpha; + // gemm arguments. + Index m; + const Index n, k, ldc; + const Index inc; + const Scalar *alpha; - const Scalar *a, *b; - Scalar *c; + const Scalar *a, *b; + Scalar *c; - const bool is_alpha1; - const bool is_beta0; + const bool is_alpha1; + const bool is_beta0; - const Index a_stride, b_stride; - const Index a_off, b_off; + const Index a_stride, b_stride; + const Index a_off, b_off; - static EIGEN_ALWAYS_INLINE constexpr int div_up(int a, int b) { - return a == 0 ? 0 : (a - 1) / b + 1; - } + static EIGEN_ALWAYS_INLINE constexpr int div_up(int a, int b) { return a == 0 ? 0 : (a - 1) / b + 1; } - EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) - { - _mm_prefetch((char *) (a_prefetch_size + a_addr - a_shift), _MM_HINT_T0); - } + EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) { + _mm_prefetch((char *)(a_prefetch_size + a_addr - a_shift), _MM_HINT_T0); + } - EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) - { - _mm_prefetch((char *) (b_prefetch_size + b_addr - b_shift), _MM_HINT_T0); - } + EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) { + _mm_prefetch((char *)(b_prefetch_size + b_addr - b_shift), _MM_HINT_T0); + } - EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) - { - _mm_prefetch((char *) (x_addr - a_shift), _MM_HINT_T2); - } + EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) { _mm_prefetch((char *)(x_addr - a_shift), _MM_HINT_T2); } - EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) - { + EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) { #if defined(__PRFCHW__) && __PRFCHW__ == 1 - _m_prefetchw((void *) c_addr); + _m_prefetchw((void *)c_addr); #else - _mm_prefetch((char *) c_addr, _MM_HINT_T0); + _mm_prefetch((char *)c_addr, _MM_HINT_T0); #endif - } + } - template - EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) - { - switch (nelems * sizeof(*a_addr) * 8) { + template + EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) { + switch (nelems * sizeof(*a_addr) * 8) { + default: + case 512 * 3: + a_reg = ploadu(a_addr); + break; + case 512 * 2: + a_reg = ploadu(a_addr); + break; + case 512 * 1: + a_reg = ploadu(a_addr); + break; + case 256 * 1: + a_reg = preinterpret(_mm512_broadcast_f64x4(ploadu(reinterpret_cast(a_addr)))); + break; + case 128 * 1: + a_reg = preinterpret(_mm512_broadcast_f32x4(ploadu(reinterpret_cast(a_addr)))); + break; + case 64 * 1: + a_reg = preinterpret(pload1(reinterpret_cast(a_addr))); + break; + case 32 * 1: + a_reg = pload1(a_addr); + break; + } + } + + EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) { b_reg = pload1(b_addr); } + + template + EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) { + if (is_unit_inc) { + switch (nelems * sizeof(*mem) * 8) { default: - case 512 * 3: a_reg = ploadu(a_addr); break; - case 512 * 2: a_reg = ploadu(a_addr); break; - case 512 * 1: a_reg = ploadu(a_addr); break; - case 256 * 1: a_reg = preinterpret(_mm512_broadcast_f64x4(ploadu(reinterpret_cast(a_addr)))); break; - case 128 * 1: a_reg = 
preinterpret(_mm512_broadcast_f32x4(ploadu(reinterpret_cast(a_addr)))); break; - case 64 * 1: a_reg = preinterpret(pload1(reinterpret_cast(a_addr))); break; - case 32 * 1: a_reg = pload1(a_addr); break; - } + case 512 * 3: + pstoreu(mem, src); + break; + case 512 * 2: + pstoreu(mem, src); + break; + case 512 * 1: + pstoreu(mem, src); + break; + case 256 * 1: + pstoreu(mem, preinterpret(src)); + break; + case 128 * 1: + pstoreu(mem, preinterpret(src)); + break; + case 64 * 1: + pstorel(mem, preinterpret(src)); + break; + case 32 * 1: + pstores(mem, preinterpret(src)); + break; + } + } else { + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: + pscatter(mem, src, inc); + break; + case 512 * 2: + pscatter(mem, src, inc); + break; + case 512 * 1: + pscatter(mem, src, inc); + break; + case 256 * 1: + pscatter(mem, src, inc, mask); + break; + case 128 * 1: + pscatter(mem, src, inc, mask); + break; + case 64 * 1: + pscatter(mem, src, inc, mask); + break; + case 32 * 1: + pscatter(mem, src, inc, mask); + break; + } } + } - EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) - { - b_reg = pload1(b_addr); + template + EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec ®) { + if (is_unit_inc) { + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: + dst = padd(src, ploadu(mem)); + break; + case 512 * 2: + dst = padd(src, ploadu(mem)); + break; + case 512 * 1: + dst = padd(src, ploadu(mem)); + break; + case 256 * 1: + dst = preinterpret(padd(preinterpret(src), ploadu(mem))); + break; + case 128 * 1: + dst = preinterpret(padd(preinterpret(src), ploadu(mem))); + break; + case 64 * 1: + dst = preinterpret(padd(preinterpret(src), ploadl(mem))); + break; + case 32 * 1: + dst = preinterpret(padds(preinterpret(src), ploads(mem))); + break; + } + } else { + // Zero out scratch register + reg = pzero(reg); + + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: + reg = pgather(mem, inc); + dst = padd(src, reg); + break; + case 512 * 2: + reg = pgather(mem, inc); + dst = padd(src, reg); + break; + case 512 * 1: + reg = pgather(mem, inc); + dst = padd(src, reg); + break; + case 256 * 1: + reg = preinterpret(pgather(mem, inc)); + dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); + break; + case 128 * 1: + reg = preinterpret(pgather(mem, inc)); + dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); + break; + case 64 * 1: + if (is_f32) { + reg = pgather(reg, mem, inc, mask); + dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); + } else { + dst = preinterpret(padd(preinterpret(src), ploadl(mem))); + } + break; + case 32 * 1: + dst = preinterpret(padds(preinterpret(src), ploads(mem))); + break; + } } + } - template - EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) - { - if (is_unit_inc) { - switch (nelems * sizeof(*mem) * 8) { - default: - case 512 * 3: pstoreu(mem, src); break; - case 512 * 2: pstoreu(mem, src); break; - case 512 * 1: pstoreu(mem, src); break; - case 256 * 1: pstoreu(mem, preinterpret(src)); break; - case 128 * 1: pstoreu(mem, preinterpret(src)); break; - case 64 * 1: pstorel(mem, preinterpret(src)); break; - case 32 * 1: pstores(mem, preinterpret(src)); break; - } - } else { - switch (nelems * sizeof(*mem) * 8) { - default: - case 512 * 3: pscatter(mem, src, inc); break; - case 512 * 2: pscatter(mem, src, inc); break; - case 512 * 1: pscatter(mem, src, inc); break; - case 256 * 1: pscatter(mem, src, inc, mask); break; - case 128 * 1: pscatter(mem, src, inc, mask); break; - case 
64 * 1: pscatter(mem, src, inc, mask); break; - case 32 * 1: pscatter(mem, src, inc, mask); break; - } - } - } - - template - EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec ®) - { - if (is_unit_inc) { - switch (nelems * sizeof(*mem) * 8) { - default: - case 512 * 3: dst = padd(src, ploadu(mem)); break; - case 512 * 2: dst = padd(src, ploadu(mem)); break; - case 512 * 1: dst = padd(src, ploadu(mem)); break; - case 256 * 1: dst = preinterpret(padd(preinterpret(src), ploadu(mem))); break; - case 128 * 1: dst = preinterpret(padd(preinterpret(src), ploadu(mem))); break; - case 64 * 1: dst = preinterpret(padd(preinterpret(src), ploadl(mem))); break; - case 32 * 1: dst = preinterpret(padds(preinterpret(src), ploads(mem))); break; - } - } else { - // Zero out scratch register - reg = pzero(reg); - - switch (nelems * sizeof(*mem) * 8) { - default: - case 512 * 3: reg = pgather(mem, inc); dst = padd(src, reg); break; - case 512 * 2: reg = pgather(mem, inc); dst = padd(src, reg); break; - case 512 * 1: reg = pgather(mem, inc); dst = padd(src, reg); break; - case 256 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); break; - case 128 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); break; - case 64 * 1: if (is_f32) { - reg = pgather(reg, mem, inc, mask); - dst = preinterpret(padd(preinterpret(src), preinterpret(reg))); - } else { - dst = preinterpret(padd(preinterpret(src), ploadl(mem))); - } - break; - case 32 * 1: dst = preinterpret(padds(preinterpret(src), ploads(mem))); break; - } - } - } - - EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) { - dst = pmadd(src1, src2, dst); + EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) { + dst = pmadd(src1, src2, dst); #if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0) - // Workaround register spills for gcc and clang - __asm__ ("#" : [dst] "+v" (dst) : [src1] "%v" (src1), [src2] "v" (src2)); + // Workaround register spills for gcc and clang + __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2)); #endif + } + + template + EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec ®) { + if (is_unit_inc) { + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: + dst = pmadd(scale, src, ploadu(mem)); + break; + case 512 * 2: + dst = pmadd(scale, src, ploadu(mem)); + break; + case 512 * 1: + dst = pmadd(scale, src, ploadu(mem)); + break; + case 256 * 1: + dst = + preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); + break; + case 128 * 1: + dst = + preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); + break; + case 64 * 1: + dst = + preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadl(mem))); + break; + case 32 * 1: + dst = + preinterpret(pmadds(preinterpret(scale), preinterpret(src), ploads(mem))); + break; + } + } else { + // Zero out scratch register + reg = pzero(reg); + + switch (nelems * sizeof(*mem) * 8) { + default: + case 512 * 3: + reg = pgather(mem, inc); + dst = pmadd(scale, src, reg); + break; + case 512 * 2: + reg = pgather(mem, inc); + dst = pmadd(scale, src, reg); + break; + case 512 * 1: + reg = pgather(mem, inc); + dst = pmadd(scale, src, reg); + break; + case 256 * 1: + reg = preinterpret(pgather(mem, inc)); + dst = preinterpret( + pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); + break; + case 128 * 1: + reg = 
preinterpret(pgather(mem, inc)); + dst = preinterpret( + pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); + break; + case 64 * 1: + if (is_f32) { + reg = pgather(reg, mem, inc, mask); + dst = preinterpret( + pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); + } else { + dst = preinterpret( + pmadd(preinterpret(scale), preinterpret(src), ploadl(mem))); + } + break; + case 32 * 1: + dst = + preinterpret(pmadds(preinterpret(scale), preinterpret(src), ploads(mem))); + break; + } } + } - template - EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec ®) - { - if (is_unit_inc) { - switch (nelems * sizeof(*mem) * 8) { - default: - case 512 * 3: dst = pmadd(scale, src, ploadu(mem)); break; - case 512 * 2: dst = pmadd(scale, src, ploadu(mem)); break; - case 512 * 1: dst = pmadd(scale, src, ploadu(mem)); break; - case 256 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); break; - case 128 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadu(mem))); break; - case 64 * 1: dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadl(mem))); break; - case 32 * 1: dst = preinterpret(pmadds(preinterpret(scale), preinterpret(src), ploads(mem))); break; - } - } else { - // Zero out scratch register - reg = pzero(reg); + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) { + EIGEN_UNUSED_VARIABLE(ao); + } - switch (nelems * sizeof(*mem) * 8) { - default: - case 512 * 3: reg = pgather(mem, inc); dst = pmadd(scale, src, reg); break; - case 512 * 2: reg = pgather(mem, inc); dst = pmadd(scale, src, reg); break; - case 512 * 1: reg = pgather(mem, inc); dst = pmadd(scale, src, reg); break; - case 256 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); break; - case 128 * 1: reg = preinterpret(pgather(mem, inc)); dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); break; - case 64 * 1: if (is_f32) { - reg = pgather(reg, mem, inc, mask); - dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), preinterpret(reg))); - } else { - dst = preinterpret(pmadd(preinterpret(scale), preinterpret(src), ploadl(mem))); - } - break; - case 32 * 1: dst = preinterpret(pmadds(preinterpret(scale), preinterpret(src), ploads(mem))); break; - } - } + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) { + if (j < endX) { + if (i < endY) { + auto &a_reg = zmm[a_regs[i + (j % 2) * 3]]; + const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift; + a_load(a_reg, a_addr); + + a_loads(ao); + } else { + a_loads(ao); + } } + } - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> - a_loads(const Scalar *ao) - { - EIGEN_UNUSED_VARIABLE(ao); - } + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1, + const Scalar *co2) { + EIGEN_UNUSED_VARIABLE(co1); + EIGEN_UNUSED_VARIABLE(co2); + } - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> - a_loads(const Scalar *ao) - { - if (j < endX) { - if (i < endY) { - auto &a_reg = zmm[a_regs[i + (j % 2) * 3]]; - const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift; - a_load(a_reg, a_addr); + /* C prefetch loop structure. 
+ * for (int un = 0; un < 8; un++) { + * if (b_unroll >= un + 1) { + * if (un == 4) co2 = co1 + 4 * ldc; + * + * for (int i = 0; i < um_vecs; i++) { + * Scalar *co = (un + 1 <= 4) ? co1 : co2; + * auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; + * prefetch_c(co + co_off); + * } + * } + * } + */ - a_loads(ao); - } else { - a_loads(ao); - } - } - } + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) { + if (un < max_b_unroll) { + if (b_unroll >= un + 1) { + if (un == 4 && i == 0) co2 = co1 + 4 * ldc; - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> - prefetch_cs(const Scalar *co1, const Scalar *co2) - { - EIGEN_UNUSED_VARIABLE(co1); - EIGEN_UNUSED_VARIABLE(co2); - } - - /* C prefetch loop structure. - * for (int un = 0; un < 8; un++) { - * if (b_unroll >= un + 1) { - * if (un == 4) co2 = co1 + 4 * ldc; - * - * for (int i = 0; i < um_vecs; i++) { - * Scalar *co = (un + 1 <= 4) ? co1 : co2; - * auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; - * prefetch_c(co + co_off); - * } - * } - * } - */ - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> - prefetch_cs(Scalar *&co1, Scalar *&co2) - { - if (un < max_b_unroll) { - - if (b_unroll >= un + 1) { - if (un == 4 && i == 0) co2 = co1 + 4 * ldc; - - if (i < um_vecs) { - Scalar *co = (un + 1 <= 4) ? co1 : co2; - auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; - prefetch_c(co + co_off); - - prefetch_cs(co1, co2); - } else { - prefetch_cs(co1, co2); - } - - } else { - prefetch_cs(co1, co2); - } - } - } - - // load_c - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> - scale_load_c(const Scalar *cox, vec &alpha_reg) - { - EIGEN_UNUSED_VARIABLE(cox); - EIGEN_UNUSED_VARIABLE(alpha_reg); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> - scale_load_c(const Scalar *cox, vec &alpha_reg) - { if (i < um_vecs) { - auto &c_reg = zmm[c_regs[i + idx * 3]]; - auto &c_load_reg = zmm[c_load_regs[i % 3]]; - auto c_mem = cox; - if (is_unit_inc) - c_mem += i * nelems_in_cache_line; - else - c_mem += i * nelems_in_cache_line * inc; - - - if (!is_beta0 && is_alpha1) - vaddm(c_reg, c_mem, c_reg, c_load_reg); - else if (!is_beta0 && !is_alpha1) - vfmaddm(c_reg, c_mem, c_reg, alpha_reg, c_load_reg); - else if (is_beta0 && !is_alpha1) - c_reg = pmul(alpha_reg, c_reg); - - scale_load_c(cox, alpha_reg); - } - } - - // store_c - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> - write_c(Scalar *cox) - { - EIGEN_UNUSED_VARIABLE(cox); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> - write_c(Scalar *cox) - { - if (i < um_vecs) { - auto &c_reg = zmm[c_regs[i + idx * 3]]; - auto c_mem = cox; - if (is_unit_inc) - c_mem += i * nelems_in_cache_line; - else - c_mem += i * nelems_in_cache_line * inc; - - c_store(c_mem, c_reg); - c_reg = pzero(c_reg); - - write_c(cox); - } - } - - /* C update loop structure. - * co2 = co1 + ldc; - * - * auto &alpha_reg = zmm[alpha_load_reg]; - * if (!is_alpha1) alpha_reg = pload1(alpha); - * - * int idx = 0; - * for (pow = 1; pow <= 8; pow <<= 1) { - * - * if (b_unroll >= pow) { - * for (count = 1; count < (pow + 1) / 2 + 1; count++) { - * if (pow >= 4) co2 += ldc; - * - * const Scalar *cox = (idx == 0) ? 
co1 : co2; - * - * const int um_vecs = div_up(a_unroll, nelems_in_cache_line); - * scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); - * write_c<0, um_vecs, idx, a_unroll>(cox); - * - * idx++; - * } - * } - * } - * - * if (b_unroll == 1) - * co1 += ldc; - * else - * co1 = co2 + ldc; - */ - - template - EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) - { - if (pow >= 4) cox += ldc; - - const int um_vecs = div_up(a_unroll, nelems_in_cache_line); - auto &alpha_reg = zmm[alpha_load_reg]; - - scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); - write_c<0, um_vecs, idx, a_unroll>(cox); - } - - template - EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) - { - constexpr int idx = pow / 2; - Scalar *&cox = idx == 0 ? co1 : co2; - - constexpr int max_count = (pow + 1) / 2; - static_assert(max_count <= 4, "Unsupported max_count."); - - if (1 <= max_count) c_update_1count(cox); - if (2 <= max_count) c_update_1count(cox); - if (3 <= max_count) c_update_1count(cox); - if (4 <= max_count) c_update_1count(cox); - } - - template - EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) - { - auto &alpha_reg = zmm[alpha_load_reg]; - - co2 = co1 + ldc; - if (!is_alpha1) alpha_reg = pload1(alpha); - if (!is_unit_inc && a_unroll < nelems_in_cache_line) - mask = static_cast((1ull << a_unroll) - 1); - - static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll"); - - if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2); - if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2); - if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2); - if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2); - - if (b_unroll == 1) - co1 += ldc; - else - co1 = co2 + ldc; - } - - // compute - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> - compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg) - { - EIGEN_UNUSED_VARIABLE(ao); - EIGEN_UNUSED_VARIABLE(bo); - EIGEN_UNUSED_VARIABLE(fetchA_idx); - EIGEN_UNUSED_VARIABLE(fetchB_idx); - EIGEN_UNUSED_VARIABLE(b_reg); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> - compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg) - { - if (um < um_vecs) { - auto &c_reg = zmm[c_regs[um + idx * 3]]; - auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]]; - - vfmadd(c_reg, a_reg, b_reg); - - if (!fetch_x && um == 0 && (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) || (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) { - prefetch_a(ao + nelems_in_cache_line * fetchA_idx); - fetchA_idx++; - } - - if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) { - prefetch_b(bo + nelems_in_cache_line * fetchB_idx); - fetchB_idx++; - } - - compute(ao, bo, fetchA_idx, fetchB_idx, b_reg); - } - } - - // load_a - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> - load_a(const Scalar *ao) - { - EIGEN_UNUSED_VARIABLE(ao); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> - load_a(const Scalar *ao) - { - if (um < um_vecs) { - auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]]; - const Scalar *a_addr = ao - + nelems * (1 + !ktail * !use_less_a_regs + uk) - + nelems_in_cache_line * um - a_shift; - a_load(a_reg, a_addr); - - load_a(ao); - } - } - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> - innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, 
int &fetchB_idx) - { - EIGEN_UNUSED_VARIABLE(aa); - EIGEN_UNUSED_VARIABLE(ao); - EIGEN_UNUSED_VARIABLE(bo); - EIGEN_UNUSED_VARIABLE(co2); - EIGEN_UNUSED_VARIABLE(fetchA_idx); - EIGEN_UNUSED_VARIABLE(fetchB_idx); - } - - template - EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> - innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) - { - const int idx = (pow / 2) + count; - - if (count < (pow + 1) / 2) { - auto &b_reg = zmm[b_regs[idx % 2]]; - - if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); - if (fetch_x && uk == 3 && idx == 4) aa += 8; - - if (b_unroll >= pow) { - - compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); - - const Scalar *b_addr = bo + b_unroll * uk + idx + 1 - + (b_unroll > 1) * !use_less_b_regs - b_shift; - b_load(b_reg, b_addr); - } - - // Go to the next count. - innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + Scalar *co = (un + 1 <= 4) ? co1 : co2; + auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co; + prefetch_c(co + co_off); + prefetch_cs(co1, co2); } else { - // Maybe prefetch C data after count-loop. - if (pow == 2 && c_fetch) { - if (uk % 3 == 0 && uk > 0) { - co2 += ldc; - } else { - prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); - } - } + prefetch_cs(co1, co2); } + + } else { + prefetch_cs(co1, co2); + } } + } - template - EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx) - { - const int um_vecs = div_up(a_unroll, nelems_in_cache_line); + // load_c + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) { + EIGEN_UNUSED_VARIABLE(cox); + EIGEN_UNUSED_VARIABLE(alpha_reg); + } - if (max_b_unroll >= 1) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (max_b_unroll >= 2) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (max_b_unroll >= 4) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (max_b_unroll >= 8) innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) { + if (i < um_vecs) { + auto &c_reg = zmm[c_regs[i + idx * 3]]; + auto &c_load_reg = zmm[c_load_regs[i % 3]]; + auto c_mem = cox; + if (is_unit_inc) + c_mem += i * nelems_in_cache_line; + else + c_mem += i * nelems_in_cache_line * inc; - // Load A after pow-loop. - load_a<0, um_vecs, uk, a_unroll, ktail>(ao); + if (!is_beta0 && is_alpha1) + vaddm(c_reg, c_mem, c_reg, c_load_reg); + else if (!is_beta0 && !is_alpha1) + vfmaddm(c_reg, c_mem, c_reg, alpha_reg, c_load_reg); + else if (is_beta0 && !is_alpha1) + c_reg = pmul(alpha_reg, c_reg); + + scale_load_c(cox, alpha_reg); } + } - /* Inner kernel loop structure. 
- * for (int uk = 0; uk < kfactor; uk++) { - * int idx = 0; - * - * for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) { - * for (int count = 0; count < (pow + 1) / 2; count++) { - * auto &b_reg = zmm[b_regs[idx % 2]]; - * - * if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); - * if (fetch_x && uk == 3 && idx == 4) aa += 8; - * - * if (b_unroll >= pow) { - * compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); - * - * const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ; - * b_load(b_reg, b_addr); - * } - * idx++; - * } - * - * Maybe prefetch C data. - * if (pow == 2 && c_fetch) { - * if (uk % 3 == 0 && uk > 0) { - * co2 += ldc; - * } else { - * prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); - * } - * } - * } - * - * Load A. - * load_a<0, um_vecs, uk, ktail, a_unroll>(ao); - * } - * - * Advance A/B pointers after uk-loop. - * ao += a_unroll * kfactor; - * bo += b_unroll * kfactor; - */ + // store_c + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) { + EIGEN_UNUSED_VARIABLE(cox); + } - template - EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) - { - int fetchA_idx = 0; - int fetchB_idx = 0; + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) { + if (i < um_vecs) { + auto &c_reg = zmm[c_regs[i + idx * 3]]; + auto c_mem = cox; + if (is_unit_inc) + c_mem += i * nelems_in_cache_line; + else + c_mem += i * nelems_in_cache_line * inc; - const bool fetch_x = k_factor == max_k_factor; - const bool ktail = k_factor == 1; + c_store(c_mem, c_reg); + c_reg = pzero(c_reg); - static_assert(k_factor <= 4 && k_factor > 0, - "innerkernel maximum k_factor supported is 4"); - - if (k_factor > 0) innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (k_factor > 1) innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (k_factor > 2) innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - if (k_factor > 3) innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx); - - // Advance A/B pointers after uk-loop. - ao += a_unroll * k_factor; - bo += b_unroll * k_factor; + write_c(cox); } + } + /* C update loop structure. + * co2 = co1 + ldc; + * + * auto &alpha_reg = zmm[alpha_load_reg]; + * if (!is_alpha1) alpha_reg = pload1(alpha); + * + * int idx = 0; + * for (pow = 1; pow <= 8; pow <<= 1) { + * + * if (b_unroll >= pow) { + * for (count = 1; count < (pow + 1) / 2 + 1; count++) { + * if (pow >= 4) co2 += ldc; + * + * const Scalar *cox = (idx == 0) ? 
co1 : co2; + * + * const int um_vecs = div_up(a_unroll, nelems_in_cache_line); + * scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); + * write_c<0, um_vecs, idx, a_unroll>(cox); + * + * idx++; + * } + * } + * } + * + * if (b_unroll == 1) + * co1 += ldc; + * else + * co1 = co2 + ldc; + */ - template - EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) - { - const int um_vecs = div_up(a_unroll, nelems_in_cache_line); - if (!use_less_a_regs) - a_loads<0, 2, 0, um_vecs, a_unroll>(ao); - else - a_loads<0, 1, 0, um_vecs, a_unroll>(ao); + template + EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) { + if (pow >= 4) cox += ldc; - b_load(zmm[b_regs[0]], bo - b_shift + 0); - if (!use_less_b_regs) - b_load(zmm[b_regs[1]], bo - b_shift + 1); + const int um_vecs = div_up(a_unroll, nelems_in_cache_line); + auto &alpha_reg = zmm[alpha_load_reg]; + + scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg); + write_c<0, um_vecs, idx, a_unroll>(cox); + } + + template + EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) { + constexpr int idx = pow / 2; + Scalar *&cox = idx == 0 ? co1 : co2; + + constexpr int max_count = (pow + 1) / 2; + static_assert(max_count <= 4, "Unsupported max_count."); + + if (1 <= max_count) c_update_1count(cox); + if (2 <= max_count) c_update_1count(cox); + if (3 <= max_count) c_update_1count(cox); + if (4 <= max_count) c_update_1count(cox); + } + + template + EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) { + auto &alpha_reg = zmm[alpha_load_reg]; + + co2 = co1 + ldc; + if (!is_alpha1) alpha_reg = pload1(alpha); + if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast((1ull << a_unroll) - 1); + + static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll"); + + if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2); + if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2); + if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2); + if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2); + + if (b_unroll == 1) + co1 += ldc; + else + co1 = co2 + ldc; + } + + // compute + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, + int &fetchB_idx, vec &b_reg) { + EIGEN_UNUSED_VARIABLE(ao); + EIGEN_UNUSED_VARIABLE(bo); + EIGEN_UNUSED_VARIABLE(fetchA_idx); + EIGEN_UNUSED_VARIABLE(fetchB_idx); + EIGEN_UNUSED_VARIABLE(b_reg); + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, + int &fetchB_idx, vec &b_reg) { + if (um < um_vecs) { + auto &c_reg = zmm[c_regs[um + idx * 3]]; + auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]]; + + vfmadd(c_reg, a_reg, b_reg); + + if (!fetch_x && um == 0 && + (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) || + (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) { + prefetch_a(ao + nelems_in_cache_line * fetchA_idx); + fetchA_idx++; + } + + if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) { + prefetch_b(bo + nelems_in_cache_line * fetchB_idx); + fetchB_idx++; + } + + compute(ao, bo, fetchA_idx, fetchB_idx, b_reg); + } + } + + // load_a + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) { + EIGEN_UNUSED_VARIABLE(ao); + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) { + if (um < um_vecs) { + auto 
&a_reg = zmm[a_regs[um + (uk % 2) * 3]]; + const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift; + a_load(a_reg, a_addr); + + load_a(ao); + } + } + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa, + const Scalar *const &ao, + const Scalar *const &bo, Scalar *&co2, + int &fetchA_idx, int &fetchB_idx) { + EIGEN_UNUSED_VARIABLE(aa); + EIGEN_UNUSED_VARIABLE(ao); + EIGEN_UNUSED_VARIABLE(bo); + EIGEN_UNUSED_VARIABLE(co2); + EIGEN_UNUSED_VARIABLE(fetchA_idx); + EIGEN_UNUSED_VARIABLE(fetchB_idx); + } + + template + EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa, + const Scalar *const &ao, + const Scalar *const &bo, Scalar *&co2, + int &fetchA_idx, int &fetchB_idx) { + const int idx = (pow / 2) + count; + + if (count < (pow + 1) / 2) { + auto &b_reg = zmm[b_regs[idx % 2]]; + + if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); + if (fetch_x && uk == 3 && idx == 4) aa += 8; + + if (b_unroll >= pow) { + compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); + + const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift; + b_load(b_reg, b_addr); + } + + // Go to the next count. + innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, + fetchB_idx); + + } else { + // Maybe prefetch C data after count-loop. + if (pow == 2 && c_fetch) { + if (uk % 3 == 0 && uk > 0) { + co2 += ldc; + } else { + prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); + } + } + } + } + + template + EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo, + Scalar *&co2, int &fetchA_idx, int &fetchB_idx) { + const int um_vecs = div_up(a_unroll, nelems_in_cache_line); + + if (max_b_unroll >= 1) + innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + if (max_b_unroll >= 2) + innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + if (max_b_unroll >= 4) + innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + if (max_b_unroll >= 8) + innerkernel_1pow(aa, ao, bo, co2, fetchA_idx, fetchB_idx); + + // Load A after pow-loop. + load_a<0, um_vecs, uk, a_unroll, ktail>(ao); + } + + /* Inner kernel loop structure. + * for (int uk = 0; uk < kfactor; uk++) { + * int idx = 0; + * + * for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) { + * for (int count = 0; count < (pow + 1) / 2; count++) { + * auto &b_reg = zmm[b_regs[idx % 2]]; + * + * if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa); + * if (fetch_x && uk == 3 && idx == 4) aa += 8; + * + * if (b_unroll >= pow) { + * compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg); + * + * const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ; + * b_load(b_reg, b_addr); + * } + * idx++; + * } + * + * Maybe prefetch C data. + * if (pow == 2 && c_fetch) { + * if (uk % 3 == 0 && uk > 0) { + * co2 += ldc; + * } else { + * prefetch_c(co2 + (uk % 3) * nelems_in_cache_line); + * } + * } + * } + * + * Load A. + * load_a<0, um_vecs, uk, ktail, a_unroll>(ao); + * } + * + * Advance A/B pointers after uk-loop. 
+ * ao += a_unroll * kfactor; + * bo += b_unroll * kfactor; + */ + + template + EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) { + int fetchA_idx = 0; + int fetchB_idx = 0; + + const bool fetch_x = k_factor == max_k_factor; + const bool ktail = k_factor == 1; + + static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4"); + + if (k_factor > 0) + innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, + fetchB_idx); + if (k_factor > 1) + innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, + fetchB_idx); + if (k_factor > 2) + innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, + fetchB_idx); + if (k_factor > 3) + innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, + fetchB_idx); + + // Advance A/B pointers after uk-loop. + ao += a_unroll * k_factor; + bo += b_unroll * k_factor; + } + + template + EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) { + const int um_vecs = div_up(a_unroll, nelems_in_cache_line); + if (!use_less_a_regs) + a_loads<0, 2, 0, um_vecs, a_unroll>(ao); + else + a_loads<0, 1, 0, um_vecs, a_unroll>(ao); + + b_load(zmm[b_regs[0]], bo - b_shift + 0); + if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1); #ifndef SECOND_FETCH - prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2); -#endif // SECOND_FETCH + prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2); +#endif // SECOND_FETCH - // Unrolling k-loop by a factor of 4. - const int max_k_factor = 4; - Index loop_count = k / max_k_factor; + // Unrolling k-loop by a factor of 4. + const int max_k_factor = 4; + Index loop_count = k / max_k_factor; - if (loop_count > 0) { + if (loop_count > 0) { #ifdef SECOND_FETCH - loop_count -= SECOND_FETCH; + loop_count -= SECOND_FETCH; #endif - while (loop_count > 0) { - innerkernel(aa, ao, bo, co2); - loop_count--; - } + while (loop_count > 0) { + innerkernel(aa, ao, bo, co2); + loop_count--; + } #ifdef SECOND_FETCH - co2 = co1 + nelems_in_cache_line - 1; + co2 = co1 + nelems_in_cache_line - 1; - loop_count += b_unroll; - while (loop_count > 0) { - innerkernel(aa, ao, bo, co2); - loop_count--; - } + loop_count += b_unroll; + while (loop_count > 0) { + innerkernel(aa, ao, bo, co2); + loop_count--; + } - loop_count += SECOND_FETCH - b_unroll; - while (loop_count > 0) { - innerkernel(aa, ao, bo, co2); - loop_count--; - } + loop_count += SECOND_FETCH - b_unroll; + while (loop_count > 0) { + innerkernel(aa, ao, bo, co2); + loop_count--; + } #endif - } - - // k-loop remainder handling. - loop_count = k % max_k_factor; - while (loop_count > 0) { - innerkernel(aa, ao, bo, co2); - loop_count--; - } - - // Update C matrix. - c_update(co1, co2); } - template - EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) - { - // Set A matrix pointer. - ao = a + a_off * a_unroll; - - // Set B matrix pointer if needed. - bo += b_unroll * b_off; - - kloop(aa, ao, bo, co1, co2); - - // Advance B matrix pointer if needed. - bo += b_unroll * (b_stride - k - b_off); - - // Advance prefetch A pointer. - aa += 16; + // k-loop remainder handling. 
+ loop_count = k % max_k_factor; + while (loop_count > 0) { + innerkernel(aa, ao, bo, co2); + loop_count--; } - template - EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) - { - // Set prefetch A pointers. - const Scalar *aa = a + a_unroll * a_stride; + // Update C matrix. + c_update(co1, co2); + } - // Set C matrix pointers. - co1 = c; - if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc; - if (is_unit_inc) - c += a_unroll; - else - c += a_unroll * inc; + template + EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) { + // Set A matrix pointer. + ao = a + a_off * a_unroll; - // Set B matrix pointer. - bo = b; + // Set B matrix pointer if needed. + bo += b_unroll * b_off; - // Main n-loop. - for (Index i = n / max_b_unroll; i > 0; i--) - nloop(aa, ao, bo, co1, co2); + kloop(aa, ao, bo, co1, co2); - // n-remainders. - if (n & 4 && max_b_unroll > 4) nloop(aa, ao, bo, co1, co2); + // Advance B matrix pointer if needed. + bo += b_unroll * (b_stride - k - b_off); + + // Advance prefetch A pointer. + aa += 16; + } + + template + EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) { + // Set prefetch A pointers. + const Scalar *aa = a + a_unroll * a_stride; + + // Set C matrix pointers. + co1 = c; + if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc; + if (is_unit_inc) + c += a_unroll; + else + c += a_unroll * inc; + + // Set B matrix pointer. + bo = b; + + // Main n-loop. + for (Index i = n / max_b_unroll; i > 0; i--) nloop(aa, ao, bo, co1, co2); + + // n-remainders. + if (n & 4 && max_b_unroll > 4) nloop(aa, ao, bo, co1, co2); #if 0 if (n & 2 && max_b_unroll > 2) nloop(aa, ao, bo, co1, co2); if (n & 1 && max_b_unroll > 1) nloop(aa, ao, bo, co1, co2); #else - // Copy kernels don't support tails of n = 2 for single/double precision. - // Loop over ones. - int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0); - while (n_rem > 0) {nloop(aa, ao, bo, co1, co2); n_rem--;} + // Copy kernels don't support tails of n = 2 for single/double precision. + // Loop over ones. + int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0); + while (n_rem > 0) { + nloop(aa, ao, bo, co1, co2); + n_rem--; + } #endif - // Advance A matrix pointer. - a = ao + a_unroll * (a_stride - k - a_off); + // Advance A matrix pointer. + a = ao + a_unroll * (a_stride - k - a_off); + } + + public: + // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll. + template + EIGEN_ALWAYS_INLINE void compute_kern() { + a -= -a_shift; + b -= -b_shift; + + const Scalar *ao = nullptr; + const Scalar *bo = nullptr; + Scalar *co1 = nullptr; + Scalar *co2 = nullptr; + + // Main m-loop. + for (; m >= max_a_unroll; m -= max_a_unroll) mloop(ao, bo, co1, co2); + + // m-remainders. + if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + + // Copy kernels don't support tails of m = 2 for single precision. + // Loop over ones. 
+ if (is_f32) { + int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0); + while (m_rem > 0) { + mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); + m_rem--; + } } + } -public: - // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll. - template - EIGEN_ALWAYS_INLINE void compute_kern() - { - a -= -a_shift; - b -= -b_shift; - - const Scalar *ao = nullptr; - const Scalar *bo = nullptr; - Scalar *co1 = nullptr; - Scalar *co2 = nullptr; - - // Main m-loop. - for (; m >= max_a_unroll; m -= max_a_unroll) - mloop(ao, bo, co1, co2); - - // m-remainders. - if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 8 && max_a_unroll > 8) mloop< 8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 4 && max_a_unroll > 4) mloop< 4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 2 && max_a_unroll > 2 && is_f64) mloop< 2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - if (m & 1 && max_a_unroll > 1 && is_f64) mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); - - // Copy kernels don't support tails of m = 2 for single precision. - // Loop over ones. - if (is_f32) { - int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0); - while (m_rem > 0) {mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); m_rem--;} - } - } - - gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, - const Scalar *alpha_, - const Scalar *a_, const Scalar *b_, Scalar *c_, - bool is_alpha1_, bool is_beta0_, - Index a_stride_, Index b_stride_, - Index a_off_, Index b_off_) - : m(m_) - , n(n_) - , k(k_) - , ldc(ldc_) - , inc(inc_) - , alpha(alpha_) - , a(a_) - , b(b_) - , c(c_) - , is_alpha1(is_alpha1_) - , is_beta0(is_beta0_) - , a_stride(a_stride_) - , b_stride(b_stride_) - , a_off(a_off_) - , b_off(b_off_) - { - // Zero out all accumulation registers. - zmm[8 ] = pzero(zmm[8 ]); - zmm[9 ] = pzero(zmm[9 ]); - zmm[10] = pzero(zmm[10]); - zmm[11] = pzero(zmm[11]); - zmm[12] = pzero(zmm[12]); - zmm[13] = pzero(zmm[13]); - zmm[14] = pzero(zmm[14]); - zmm[15] = pzero(zmm[15]); - zmm[16] = pzero(zmm[16]); - zmm[17] = pzero(zmm[17]); - zmm[18] = pzero(zmm[18]); - zmm[19] = pzero(zmm[19]); - zmm[20] = pzero(zmm[20]); - zmm[21] = pzero(zmm[21]); - zmm[22] = pzero(zmm[22]); - zmm[23] = pzero(zmm[23]); - zmm[24] = pzero(zmm[24]); - zmm[25] = pzero(zmm[25]); - zmm[26] = pzero(zmm[26]); - zmm[27] = pzero(zmm[27]); - zmm[28] = pzero(zmm[28]); - zmm[29] = pzero(zmm[29]); - zmm[30] = pzero(zmm[30]); - zmm[31] = pzero(zmm[31]); - } + gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_, + const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_, + Index a_off_, Index b_off_) + : m(m_), + n(n_), + k(k_), + ldc(ldc_), + inc(inc_), + alpha(alpha_), + a(a_), + b(b_), + c(c_), + is_alpha1(is_alpha1_), + is_beta0(is_beta0_), + a_stride(a_stride_), + b_stride(b_stride_), + a_off(a_off_), + b_off(b_off_) { + // Zero out all accumulation registers. 
+ zmm[8] = pzero(zmm[8]); + zmm[9] = pzero(zmm[9]); + zmm[10] = pzero(zmm[10]); + zmm[11] = pzero(zmm[11]); + zmm[12] = pzero(zmm[12]); + zmm[13] = pzero(zmm[13]); + zmm[14] = pzero(zmm[14]); + zmm[15] = pzero(zmm[15]); + zmm[16] = pzero(zmm[16]); + zmm[17] = pzero(zmm[17]); + zmm[18] = pzero(zmm[18]); + zmm[19] = pzero(zmm[19]); + zmm[20] = pzero(zmm[20]); + zmm[21] = pzero(zmm[21]); + zmm[22] = pzero(zmm[22]); + zmm[23] = pzero(zmm[23]); + zmm[24] = pzero(zmm[24]); + zmm[25] = pzero(zmm[25]); + zmm[26] = pzero(zmm[26]); + zmm[27] = pzero(zmm[27]); + zmm[28] = pzero(zmm[28]); + zmm[29] = pzero(zmm[29]); + zmm[30] = pzero(zmm[30]); + zmm[31] = pzero(zmm[31]); + } }; // Compute kernel with max unroll support of: @@ -856,81 +919,74 @@ public: // max_a_unroll: 24, 16, 8, 4, 2, 1 // max_b_unroll: 8, 4, 2, 1 template -EIGEN_DONT_INLINE void gemm_kern_avx512( - Index m, Index n, Index k, - Scalar *alpha, const Scalar *a, const Scalar *b, Scalar *c, - Index ldc, Index inc = 1, - Index a_stride = -1, Index b_stride = -1, - Index a_off = 0, Index b_off = 0) -{ - if (a_stride == -1) a_stride = k; - if (b_stride == -1) b_stride = k; +EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b, + Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1, + Index a_off = 0, Index b_off = 0) { + if (a_stride == -1) a_stride = k; + if (b_stride == -1) b_stride = k; - gemm_class g(m, n, k, ldc, inc, alpha, - a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off, b_off); - g.template compute_kern(); + gemm_class g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off, + b_off); + g.template compute_kern(); } -template -class gebp_traits : - public gebp_traits { +template +class gebp_traits + : public gebp_traits { using Base = gebp_traits; -public: - enum {nr = Base::Vectorizable ? 8 : 4}; + public: + enum { nr = Base::Vectorizable ? 8 : 4 }; }; -template -class gebp_traits : - public gebp_traits { +template +class gebp_traits + : public gebp_traits { using Base = gebp_traits; -public: - enum {nr = Base::Vectorizable ? 8 : 4}; + public: + enum { nr = Base::Vectorizable ? 8 : 4 }; }; -template -struct gemm_pack_rhs -{ +template +struct gemm_pack_rhs { typedef typename packet_traits::type Packet; typedef typename DataMapper::LinearMapper LinearMapper; enum { PacketSize = packet_traits::size }; - EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); + EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0, + Index offset = 0); }; -template -EIGEN_DONT_INLINE void gemm_pack_rhs - ::operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) -{ +template +EIGEN_DONT_INLINE void gemm_pack_rhs::operator()( + Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) { constexpr int nr = 8; EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR"); EIGEN_UNUSED_VARIABLE(stride); EIGEN_UNUSED_VARIABLE(offset); - eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); + eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride)); conj_if::IsComplex && Conjugate> cj; - Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0; - Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0; + Index packet_cols8 = nr >= 8 ? 
(cols / 8) * 8 : 0; + Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0; Index count = 0; - const Index peeled_k = (depth/PacketSize)*PacketSize; - if(nr>=8) - { - for(Index j2=0; j2= 8) { + for (Index j2 = 0; j2 < packet_cols8; j2 += 8) { // skip what we have before - if(PanelMode) count += 8 * offset; - const LinearMapper dm0 = rhs.getLinearMapper(0, j2+0); - const LinearMapper dm1 = rhs.getLinearMapper(0, j2+1); - const LinearMapper dm2 = rhs.getLinearMapper(0, j2+2); - const LinearMapper dm3 = rhs.getLinearMapper(0, j2+3); - const LinearMapper dm4 = rhs.getLinearMapper(0, j2+4); - const LinearMapper dm5 = rhs.getLinearMapper(0, j2+5); - const LinearMapper dm6 = rhs.getLinearMapper(0, j2+6); - const LinearMapper dm7 = rhs.getLinearMapper(0, j2+7); - Index k=0; - if((PacketSize%8)==0) // TODO enable vectorized transposition for PacketSize==4 + if (PanelMode) count += 8 * offset; + const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4); + const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5); + const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6); + const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7); + Index k = 0; + if ((PacketSize % 8) == 0) // TODO enable vectorized transposition for PacketSize==4 { - for(; k kernel; + for (; k < peeled_k; k += PacketSize) { + PacketBlock kernel; kernel.packet[0] = dm0.template loadPacket(k); kernel.packet[1] = dm1.template loadPacket(k); @@ -943,244 +999,227 @@ EIGEN_DONT_INLINE void gemm_pack_rhs=4) - { - for(Index j2=packet_cols8; j2= 4) { + for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) { // skip what we have before - if(PanelMode) count += 4 * offset; + if (PanelMode) count += 4 * offset; const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0); const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1); const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2); const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3); - Index k=0; - if((PacketSize%4)==0) // TODO enable vectorized transposition for PacketSize==2 ?? + Index k = 0; + if ((PacketSize % 4) == 0) // TODO enable vectorized transposition for PacketSize==2 ?? 
{ - for(; k kernel; - kernel.packet[0 ] = dm0.template loadPacket(k); - kernel.packet[1%PacketSize] = dm1.template loadPacket(k); - kernel.packet[2%PacketSize] = dm2.template loadPacket(k); - kernel.packet[3%PacketSize] = dm3.template loadPacket(k); + for (; k < peeled_k; k += PacketSize) { + PacketBlock kernel; + kernel.packet[0] = dm0.template loadPacket(k); + kernel.packet[1 % PacketSize] = dm1.template loadPacket(k); + kernel.packet[2 % PacketSize] = dm2.template loadPacket(k); + kernel.packet[3 % PacketSize] = dm3.template loadPacket(k); ptranspose(kernel); - pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0])); - pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize])); - pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize])); - pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize])); - count+=4*PacketSize; + pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0])); + pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize])); + pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize])); + pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize])); + count += 4 * PacketSize; } } - for(; k -struct gemm_pack_rhs -{ +template +struct gemm_pack_rhs { typedef typename packet_traits::type Packet; typedef typename unpacket_traits::half HalfPacket; typedef typename unpacket_traits::half>::half QuarterPacket; typedef typename DataMapper::LinearMapper LinearMapper; - enum { PacketSize = packet_traits::size, - HalfPacketSize = unpacket_traits::size, - QuarterPacketSize = unpacket_traits::size}; - EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0) - { + enum { + PacketSize = packet_traits::size, + HalfPacketSize = unpacket_traits::size, + QuarterPacketSize = unpacket_traits::size + }; + EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0, + Index offset = 0) { constexpr int nr = 8; EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR"); EIGEN_UNUSED_VARIABLE(stride); EIGEN_UNUSED_VARIABLE(offset); - eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); + eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride)); const bool HasHalf = (int)HalfPacketSize < (int)PacketSize; const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize; conj_if::IsComplex && Conjugate> cj; - Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0; - Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0; + Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0; + Index packet_cols4 = nr >= 4 ? 
(cols / 4) * 4 : 0; Index count = 0; - if(nr>=8) - { - for(Index j2=0; j2= 8) { + for (Index j2 = 0; j2 < packet_cols8; j2 += 8) { // skip what we have before - if(PanelMode) count += 8 * offset; - for(Index k=0; k(&rhs.data()[k*rhs.stride() + j2]); Packet A = rhs.template loadPacket(k, j2); - pstoreu(blockB+count, cj.pconj(A)); - } else if (HasHalf && HalfPacketSize==8) { + pstoreu(blockB + count, cj.pconj(A)); + } else if (HasHalf && HalfPacketSize == 8) { HalfPacket A = rhs.template loadPacket(k, j2); - pstoreu(blockB+count, cj.pconj(A)); - } else if (HasQuarter && QuarterPacketSize==8) { + pstoreu(blockB + count, cj.pconj(A)); + } else if (HasQuarter && QuarterPacketSize == 8) { QuarterPacket A = rhs.template loadPacket(k, j2); - pstoreu(blockB+count, cj.pconj(A)); - } else if (PacketSize==4) { + pstoreu(blockB + count, cj.pconj(A)); + } else if (PacketSize == 4) { // Packet A = ploadu(&rhs.data()[k*rhs.stride() + j2]); // Packet B = ploadu(&rhs.data()[k*rhs.stride() + j2 + PacketSize]); Packet A = rhs.template loadPacket(k, j2); Packet B = rhs.template loadPacket(k, j2 + PacketSize); - pstoreu(blockB+count, cj.pconj(A)); - pstoreu(blockB+count+PacketSize, cj.pconj(B)); + pstoreu(blockB + count, cj.pconj(A)); + pstoreu(blockB + count + PacketSize, cj.pconj(B)); } else { // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2]; const LinearMapper dm0 = rhs.getLinearMapper(k, j2); - blockB[count+0] = cj(dm0(0)); - blockB[count+1] = cj(dm0(1)); - blockB[count+2] = cj(dm0(2)); - blockB[count+3] = cj(dm0(3)); - blockB[count+4] = cj(dm0(4)); - blockB[count+5] = cj(dm0(5)); - blockB[count+6] = cj(dm0(6)); - blockB[count+7] = cj(dm0(7)); + blockB[count + 0] = cj(dm0(0)); + blockB[count + 1] = cj(dm0(1)); + blockB[count + 2] = cj(dm0(2)); + blockB[count + 3] = cj(dm0(3)); + blockB[count + 4] = cj(dm0(4)); + blockB[count + 5] = cj(dm0(5)); + blockB[count + 6] = cj(dm0(6)); + blockB[count + 7] = cj(dm0(7)); } count += 8; } // skip what we have after - if(PanelMode) count += 8 * (stride-offset-depth); + if (PanelMode) count += 8 * (stride - offset - depth); } } - if(nr>=4) - { - for(Index j2=packet_cols8; j2= 4) { + for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) { // skip what we have before - if(PanelMode) count += 4 * offset; - for(Index k=0; k(k, j2); - pstoreu(blockB+count, cj.pconj(A)); + pstoreu(blockB + count, cj.pconj(A)); count += PacketSize; - } else if (HasHalf && HalfPacketSize==4) { + } else if (HasHalf && HalfPacketSize == 4) { HalfPacket A = rhs.template loadPacket(k, j2); - pstoreu(blockB+count, cj.pconj(A)); + pstoreu(blockB + count, cj.pconj(A)); count += HalfPacketSize; - } else if (HasQuarter && QuarterPacketSize==4) { + } else if (HasQuarter && QuarterPacketSize == 4) { QuarterPacket A = rhs.template loadPacket(k, j2); - pstoreu(blockB+count, cj.pconj(A)); + pstoreu(blockB + count, cj.pconj(A)); count += QuarterPacketSize; } else { const LinearMapper dm0 = rhs.getLinearMapper(k, j2); - blockB[count+0] = cj(dm0(0)); - blockB[count+1] = cj(dm0(1)); - blockB[count+2] = cj(dm0(2)); - blockB[count+3] = cj(dm0(3)); + blockB[count + 0] = cj(dm0(0)); + blockB[count + 1] = cj(dm0(1)); + blockB[count + 2] = cj(dm0(2)); + blockB[count + 3] = cj(dm0(3)); count += 4; } } // skip what we have after - if(PanelMode) count += 4 * (stride-offset-depth); + if (PanelMode) count += 4 * (stride - offset - depth); } } // copy the remaining columns one at a time (nr==1) - for(Index j2=packet_cols4; j2 -struct gebp_kernel -{ +template +struct gebp_kernel { EIGEN_ALWAYS_INLINE - void 
operator()(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, - Index rows, Index depth, Index cols, Scalar alpha, - Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); + void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, + Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0, + Index offsetB = 0); }; -template -EIGEN_ALWAYS_INLINE -void gebp_kernel - ::operator()(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, - Index rows, Index depth, Index cols, Scalar alpha, - Index strideA, Index strideB, Index offsetA, Index offsetB) -{ +template +EIGEN_ALWAYS_INLINE void gebp_kernel::operator()( + const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols, + Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { if (res.incr() == 1) { if (alpha == 1) { - gemm_kern_avx512(rows, cols, depth, - &alpha, blockA, blockB, (Scalar *)res.data(), res.stride(), - res.incr(), strideA, strideB, offsetA, offsetB); + gemm_kern_avx512(rows, cols, depth, &alpha, blockA, blockB, + (Scalar *)res.data(), res.stride(), res.incr(), strideA, + strideB, offsetA, offsetB); } else { - gemm_kern_avx512(rows, cols, depth, - &alpha, blockA, blockB, (Scalar *)res.data(), - res.stride(), res.incr(), strideA, strideB, offsetA, offsetB); + gemm_kern_avx512(rows, cols, depth, &alpha, blockA, blockB, + (Scalar *)res.data(), res.stride(), res.incr(), strideA, + strideB, offsetA, offsetB); } } else { if (alpha == 1) { - gemm_kern_avx512(rows, cols, depth, - &alpha, blockA, blockB, (Scalar *)res.data(), res.stride(), - res.incr(), strideA, strideB, offsetA, offsetB); + gemm_kern_avx512(rows, cols, depth, &alpha, blockA, blockB, + (Scalar *)res.data(), res.stride(), res.incr(), strideA, + strideB, offsetA, offsetB); } else { - gemm_kern_avx512(rows, cols, depth, - &alpha, blockA, blockB, (Scalar *)res.data(), res.stride(), - res.incr(), strideA, strideB, offsetA, offsetB); + gemm_kern_avx512(rows, cols, depth, &alpha, blockA, blockB, + (Scalar *)res.data(), res.stride(), res.incr(), strideA, + strideB, offsetA, offsetB); } } } -} // namespace Eigen -} // namespace internal +} // namespace internal +} // namespace Eigen -#endif // GEMM_KERNEL_H +#endif // GEMM_KERNEL_H diff --git a/Eigen/src/Core/arch/AVX512/TrsmKernel.h b/Eigen/src/Core/arch/AVX512/TrsmKernel.h index b39524958..1b351ea56 100644 --- a/Eigen/src/Core/arch/AVX512/TrsmKernel.h +++ b/Eigen/src/Core/arch/AVX512/TrsmKernel.h @@ -12,13 +12,20 @@ #include "../../InternalHeaderCheck.h" -#define EIGEN_USE_AVX512_TRSM_KERNELS // Comment out to prevent using optimized trsm kernels. 
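The hunk continuing below turns the old comment-out switches into 0/1-valued macros, so the AVX512 TRSM kernels can now be disabled from the build line (or before including Eigen) instead of by editing this header. A minimal usage sketch, not part of the patch:

    // Opt out of the optimized AVX512 TRSM kernels for this translation unit;
    // equivalently pass -DEIGEN_USE_AVX512_TRSM_KERNELS=0 on the compile line.
    #define EIGEN_USE_AVX512_TRSM_KERNELS 0
    #include <Eigen/Dense>
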
- -#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) -#define EIGEN_USE_AVX512_TRSM_R_KERNELS -#if !defined(EIGEN_NO_MALLOC) // Separate MACRO since these kernels require malloc -#define EIGEN_USE_AVX512_TRSM_L_KERNELS +#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS) +#define EIGEN_USE_AVX512_TRSM_KERNELS 1 #endif + +#if EIGEN_USE_AVX512_TRSM_KERNELS +#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS) +#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1 +#endif +#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS) +#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1 +#endif +#else // EIGEN_USE_AVX512_TRSM_KERNELS == 0 +#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0 +#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0 #endif #if defined(EIGEN_HAS_CXX17_IFCONSTEXPR) @@ -49,8 +56,7 @@ typedef Packet4d vecHalfDouble; // Note: this depends on macros and typedefs above. #include "TrsmUnrolls.inc" - -#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0) +#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0) /** * For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly * is faster than the packed versions in TriangularSolverMatrix.h. @@ -67,35 +73,47 @@ typedef Packet4d vecHalfDouble; * M = Dimension of triangular matrix * */ -#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS // Comment out to disable no-copy dispatch +#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS) +#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1 +#endif -#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS) -#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS -#if !defined(EIGEN_NO_MALLOC) -#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS +#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS + +#if EIGEN_USE_AVX512_TRSM_R_KERNELS +#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS) +#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1 +#endif // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS) #endif + +#if EIGEN_USE_AVX512_TRSM_L_KERNELS +#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS) +#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1 #endif +#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS + +#else // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0 +#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0 +#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0 +#endif // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS template -int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap){ - const int64_t U3 = 3*packet_traits::size; - const int64_t MaxNb = 5*U3; +int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) { + const int64_t U3 = 3 * packet_traits::size; + const int64_t MaxNb = 5 * U3; int64_t Nb = std::min(MaxNb, N); - double cutoff_d = (((L2Size*L2Cap)/(sizeof(Scalar)))-(EIGEN_AVX_MAX_NUM_ROW)*Nb)/ - ((EIGEN_AVX_MAX_NUM_ROW)+Nb); + double cutoff_d = + (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb); int64_t cutoff_l = static_cast(cutoff_d); - return (cutoff_l/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW; + return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW; } #endif - /** * Used by gemmKernel for the case A/B row-major and C col-major. 
*/ template -static EIGEN_ALWAYS_INLINE -void transStoreC(PacketBlock &zmm, - Scalar *C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) { +static EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock &zmm, + Scalar *C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) { EIGEN_UNUSED_VARIABLE(remN_); EIGEN_UNUSED_VARIABLE(remM_); using urolls = unrolls::trans; @@ -104,76 +122,76 @@ void transStoreC(PacketBlock &zmm, constexpr int64_t U2 = urolls::PacketSize * 2; constexpr int64_t U1 = urolls::PacketSize * 1; - static_assert( unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize"); - static_assert( unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW"); + static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize"); + static_assert(unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW"); urolls::template transpose(zmm); EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose(zmm); EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose(zmm); - static_assert( (remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1"); + static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1"); EIGEN_IF_CONSTEXPR(!remN) { - urolls::template storeC(C_arr, LDC, zmm, remM_); + urolls::template storeC(C_arr, LDC, zmm, remM_); EIGEN_IF_CONSTEXPR(unrollN > U1) { - constexpr int64_t unrollN_ = std::min(unrollN-U1, U1); - urolls::template storeC(C_arr + U1*LDC, LDC, zmm, remM_); + constexpr int64_t unrollN_ = std::min(unrollN - U1, U1); + urolls::template storeC(C_arr + U1 * LDC, LDC, zmm, remM_); } EIGEN_IF_CONSTEXPR(unrollN > U2) { - constexpr int64_t unrollN_ = std::min(unrollN-U2, U1); - urolls:: template storeC(C_arr + U2*LDC, LDC, zmm, remM_); + constexpr int64_t unrollN_ = std::min(unrollN - U2, U1); + urolls::template storeC(C_arr + U2 * LDC, LDC, zmm, remM_); } } else { - EIGEN_IF_CONSTEXPR( (std::is_same::value) ) { + EIGEN_IF_CONSTEXPR((std::is_same::value)) { // Note: without "if constexpr" this section of code will also be // parsed by the compiler so each of the storeC will still be instantiated. // We use enable_if in aux_storeC to set it to an empty function for // these cases. 
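The note above refers to the usual trick of converting a runtime remainder into a compile-time template argument, with an if-constexpr (or enable_if) guard so that branches which can never be taken still compile but instantiate an empty body. A minimal sketch of the pattern with hypothetical names, not the actual aux_storeC machinery:

    // Store exactly N tail elements; N is a compile-time constant, so the body can unroll.
    template <int N>
    void store_tail(const float *src, float *dst) {
      if constexpr (N > 0) {  // keeps the N == 0 instantiation an empty function
        for (int i = 0; i < N; ++i) dst[i] = src[i];
      }
    }

    // Runtime remainder -> compile-time N, mirroring the remN_ dispatch chain below.
    void store_remainder(const float *src, float *dst, int rem) {
      switch (rem) {
        case 3: store_tail<3>(src, dst); break;
        case 2: store_tail<2>(src, dst); break;
        case 1: store_tail<1>(src, dst); break;
        default: break;
      }
    }
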
- if(remN_ == 15) + if (remN_ == 15) urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 14) + else if (remN_ == 14) urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 13) + else if (remN_ == 13) urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 12) + else if (remN_ == 12) urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 11) + else if (remN_ == 11) urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 10) + else if (remN_ == 10) urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 9) + else if (remN_ == 9) urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 8) + else if (remN_ == 8) urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 7) + else if (remN_ == 7) urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 6) + else if (remN_ == 6) urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 5) + else if (remN_ == 5) urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 4) + else if (remN_ == 4) urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 3) + else if (remN_ == 3) urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 2) + else if (remN_ == 2) urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 1) + else if (remN_ == 1) urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); } else { - if(remN_ == 7) + if (remN_ == 7) urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 6) + else if (remN_ == 6) urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 5) + else if (remN_ == 5) urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 4) + else if (remN_ == 4) urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 3) + else if (remN_ == 3) urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 2) + else if (remN_ == 2) urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); - else if(remN_ == 1) + else if (remN_ == 1) urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); } } @@ -194,505 +212,503 @@ void transStoreC(PacketBlock &zmm, * handleKRem: Handle arbitrary K? This is not needed for trsm. 
*/ template -void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, - int64_t M, int64_t N, int64_t K, - int64_t LDA, int64_t LDB, int64_t LDC) { +void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, + int64_t LDC) { using urolls = unrolls::gemm; constexpr int64_t U3 = urolls::PacketSize * 3; constexpr int64_t U2 = urolls::PacketSize * 2; constexpr int64_t U1 = urolls::PacketSize * 1; - using vec = typename std::conditional::value, - vecFullFloat, - vecFullDouble>::type; - int64_t N_ = (N/U3)*U3; - int64_t M_ = (M/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW; - int64_t K_ = (K/EIGEN_AVX_MAX_K_UNROL)*EIGEN_AVX_MAX_K_UNROL; + using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; + int64_t N_ = (N / U3) * U3; + int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW; + int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL; int64_t j = 0; - for(; j < N_; j += U3) { - constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*3; + for (; j < N_; j += U3) { + constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 3; int64_t i = 0; - for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { - Scalar *A_t = &A_arr[idA(i,0,LDA)], *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<3,EIGEN_AVX_MAX_NUM_ROW>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<3,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<3,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC+ j], LDC, zmm); + urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC); + transStoreC(zmm, &C_arr[i + j * LDC], LDC); } } - if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<3,4>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. 
Should be removed otherwise + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<3, 4>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel( + B_t, A_t, LDB, LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<3,4>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<3,4>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 4); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 4); } i += 4; } - if(M - i >= 2) { - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<3,2>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i >= 2) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<3, 2>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel( + B_t, A_t, LDB, LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<3,2>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<3,2>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 2); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 2); } i += 2; } - if(M - i > 0) { - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<3,1>(zmm); + if (M - i > 0) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<3, 1>(zmm); { - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += 
EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel(B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<3,1>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<3,1>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 1); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 1); } } } } - if(N - j >= U2) { - constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*2; + if (N - j >= U2) { + constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 2; int64_t i = 0; - for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { - Scalar *A_t = &A_arr[idA(i,0,LDA)], *B_t = &B_arr[0*LDB + j]; - EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<2,EIGEN_AVX_MAX_NUM_ROW>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j]; + EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<2,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<2,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC); + transStoreC(zmm, &C_arr[i + j * LDC], LDC); } } - if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<2,4>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. 
Should be removed otherwise + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<2, 4>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, + LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<2,4>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<2,4>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 4); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 4); } i += 4; } - if(M - i >= 2) { - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<2,2>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i >= 2) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<2, 2>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, + LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<2,2>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<2,2>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 2); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 2); } i += 2; } - if(M - i > 0) { - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<2,1>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i > 0) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<2, 1>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, + LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += 
EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel(B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<2,1>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<2,1>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 1); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 1); } } j += U2; } - if(N - j >= U1) { - constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*1; + if (N - j >= U1) { + constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1; int64_t i = 0; - for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { - Scalar *A_t = &A_arr[idA(i,0,LDA)], *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<1,EIGEN_AVX_MAX_NUM_ROW>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<1,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<1,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC); + transStoreC(zmm, &C_arr[i + j * LDC], LDC); } } - if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<1,4>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. 
Should be removed otherwise + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<1, 4>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, + LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<1,4>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<1,4>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 4); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 4); } i += 4; } - if(M - i >= 2) { - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<1,2>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i >= 2) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<1, 2>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, + LDA, zmm); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<1,2>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<1,2>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 2); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 2); } i += 2; } - if(M - i > 0) { - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<1,1>(zmm); + if (M - i > 0) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<1, 1>(zmm); { - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, + LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += 
EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel(B_t, A_t, LDB, LDA, zmm); - B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; + } } - } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<1,1>(&C_arr[i*LDC + j], LDC, zmm); - urolls::template storeC<1,1>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm); + urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 1); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 1); } } } j += U1; } - if(N - j > 0) { - constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*1; + if (N - j > 0) { + constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1; int64_t i = 0; - for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<1,EIGEN_AVX_MAX_NUM_ROW>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm, N - j); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm, N - j); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm, N - j); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm, N - j); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<1,EIGEN_AVX_MAX_NUM_ROW,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); - urolls::template storeC<1,EIGEN_AVX_MAX_NUM_ROW,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j); + urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 0, N-j); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 0, N - j); } } - if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<1,4>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm, N - j); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. 
Should be removed otherwise + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<1, 4>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm, N - j); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm, N - j); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<1,4,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); - urolls::template storeC<1,4,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j); + urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 4, N-j); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 4, N - j); } i += 4; } - if(M - i >= 2) { - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<1,2>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm, N - j); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i >= 2) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<1, 2>(zmm); + for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm, N - j); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm, N - j); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<1,2,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); - urolls::template storeC<1,2,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j); + urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 2, N-j); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 2, N - j); } i += 2; } - if(M - i > 0) { - Scalar *A_t = &A_arr[idA(i,0,LDA)]; - Scalar *B_t = &B_arr[0*LDB + j]; - PacketBlock zmm; - urolls::template setzero<1,1>(zmm); - for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm, N - j); - B_t += EIGEN_AVX_MAX_K_UNROL*LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + if (M - i > 0) { + Scalar *A_t = &A_arr[idA(i, 0, LDA)]; + Scalar *B_t = &B_arr[0 * LDB + j]; + PacketBlock zmm; + urolls::template setzero<1, 1>(zmm); + for (int64_t k = 0; k < 
K_; k += EIGEN_AVX_MAX_K_UNROL) { + urolls::template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); + B_t += EIGEN_AVX_MAX_K_UNROL * LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; + else A_t += EIGEN_AVX_MAX_K_UNROL * LDA; } EIGEN_IF_CONSTEXPR(handleKRem) { - for(int64_t k = K_; k < K ; k ++) { - urolls:: template microKernel( - B_t, A_t, LDB, LDA, zmm, N - j); + for (int64_t k = K_; k < K; k++) { + urolls::template microKernel(B_t, A_t, LDB, LDA, zmm, + N - j); B_t += LDB; - EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; + else A_t += LDA; } } EIGEN_IF_CONSTEXPR(isCRowMajor) { - urolls::template updateC<1,1,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); - urolls::template storeC<1,1,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j); + urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j); } else { - transStoreC(zmm, &C_arr[i + j*LDC], LDC, 1, N-j); + transStoreC(zmm, &C_arr[i + j * LDC], LDC, 1, N - j); } } } @@ -705,48 +721,46 @@ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, * isFWDSolve: is forward solve? * isUnitDiag: is the diagonal of A all ones? * The B matrix (RHS) is assumed to be row-major -*/ + */ template -static EIGEN_ALWAYS_INLINE -void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) { - - static_assert( unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW" ); +static EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) { + static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW"); using urolls = unrolls::trsm; constexpr int64_t U3 = urolls::PacketSize * 3; constexpr int64_t U2 = urolls::PacketSize * 2; constexpr int64_t U1 = urolls::PacketSize * 1; - PacketBlock RHSInPacket; - PacketBlock AInPacket; + PacketBlock RHSInPacket; + PacketBlock AInPacket; int64_t k = 0; - while(K - k >= U3) { - urolls:: template loadRHS(B_arr + k, LDB, RHSInPacket); - urolls:: template triSolveMicroKernel( - A_arr, LDA, RHSInPacket, AInPacket); - urolls:: template storeRHS(B_arr + k, LDB, RHSInPacket); + while (K - k >= U3) { + urolls::template loadRHS(B_arr + k, LDB, RHSInPacket); + urolls::template triSolveMicroKernel(A_arr, LDA, RHSInPacket, + AInPacket); + urolls::template storeRHS(B_arr + k, LDB, RHSInPacket); k += U3; } - if(K - k >= U2) { - urolls:: template loadRHS(B_arr + k, LDB, RHSInPacket); - urolls:: template triSolveMicroKernel( - A_arr, LDA, RHSInPacket, AInPacket); - urolls:: template storeRHS(B_arr + k, LDB, RHSInPacket); + if (K - k >= U2) { + urolls::template loadRHS(B_arr + k, LDB, RHSInPacket); + urolls::template triSolveMicroKernel(A_arr, LDA, RHSInPacket, + AInPacket); + urolls::template storeRHS(B_arr + k, LDB, RHSInPacket); k += U2; } - if(K - k >= U1) { - urolls:: template loadRHS(B_arr + k, LDB, RHSInPacket); - urolls:: template triSolveMicroKernel( - A_arr, LDA, RHSInPacket, AInPacket); - urolls:: template storeRHS(B_arr + k, LDB, RHSInPacket); + if (K - k >= U1) { + urolls::template loadRHS(B_arr + k, LDB, RHSInPacket); + urolls::template triSolveMicroKernel(A_arr, LDA, RHSInPacket, + AInPacket); + urolls::template storeRHS(B_arr + k, LDB, RHSInPacket); k += U1; } - if(K - k > 0) { + if (K - k > 0) { // Handle remaining number of RHS - urolls::template loadRHS(B_arr + k, LDB, RHSInPacket, K-k); - urolls::template triSolveMicroKernel( - 
A_arr, LDA, RHSInPacket, AInPacket); - urolls::template storeRHS(B_arr + k, LDB, RHSInPacket, K-k); + urolls::template loadRHS(B_arr + k, LDB, RHSInPacket, K - k); + urolls::template triSolveMicroKernel(A_arr, LDA, RHSInPacket, + AInPacket); + urolls::template storeRHS(B_arr + k, LDB, RHSInPacket, K - k); } } @@ -757,7 +771,7 @@ void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_ * isFWDSolve: is forward solve? * isUnitDiag: is the diagonal of A all ones? * The B matrix (RHS) is assumed to be row-major -*/ + */ template void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) { // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted @@ -790,59 +804,90 @@ void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64 * */ template -static EIGEN_ALWAYS_INLINE -void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, - Scalar *B_temp, int64_t LDB_, int64_t remM_ = 0) { +static EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_, + int64_t remM_ = 0) { EIGEN_UNUSED_VARIABLE(remM_); using urolls = unrolls::transB; using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; - PacketBlock ymm; + PacketBlock ymm; constexpr int64_t U3 = urolls::PacketSize * 3; constexpr int64_t U2 = urolls::PacketSize * 2; constexpr int64_t U1 = urolls::PacketSize * 1; - int64_t K_ = K/U3*U3; + int64_t K_ = K / U3 * U3; int64_t k = 0; - for(; k < K_; k += U3) { - urolls::template transB_kernel(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); + for (; k < K_; k += U3) { + urolls::template transB_kernel(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_); B_temp += U3; } - if(K - k >= U2) { - urolls::template transB_kernel(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); - B_temp += U2; k += U2; + if (K - k >= U2) { + urolls::template transB_kernel(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += U2; + k += U2; } - if(K - k >= U1) { - urolls::template transB_kernel(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); - B_temp += U1; k += U1; + if (K - k >= U1) { + urolls::template transB_kernel(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += U1; + k += U1; } - EIGEN_IF_CONSTEXPR( U1 > 8) { + EIGEN_IF_CONSTEXPR(U1 > 8) { // Note: without "if constexpr" this section of code will also be // parsed by the compiler so there is an additional check in {load/store}BBlock // to make sure the counter is not non-negative. - if(K - k >= 8) { - urolls::template transB_kernel<8, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); - B_temp += 8; k += 8; + if (K - k >= 8) { + urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += 8; + k += 8; } } - EIGEN_IF_CONSTEXPR( U1 > 4) { + EIGEN_IF_CONSTEXPR(U1 > 4) { // Note: without "if constexpr" this section of code will also be // parsed by the compiler so there is an additional check in {load/store}BBlock // to make sure the counter is not non-negative. 
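For orientation, the unrolled transB kernels used throughout copyBToRowMajor amount to a vectorized copy-with-transpose between the column-major panel of B and the row-major B_temp workspace (and back, depending on toTemp). A plain scalar sketch of the toTemp direction, under that assumed layout and ignoring masking and remainder handling:

    #include <cstdint>

    // Scalar equivalent of the toTemp == true path: B is column-major with leading
    // dimension LDB, B_temp is row-major with leading dimension LDB_.
    void copyPanelToRowMajor(const double *B, std::int64_t LDB, std::int64_t rows, std::int64_t K,
                             double *B_temp, std::int64_t LDB_) {
      for (std::int64_t k = 0; k < K; ++k)        // columns of the panel (the RHS index)
        for (std::int64_t r = 0; r < rows; ++r)   // rows of the panel (at most EIGEN_AVX_MAX_NUM_ROW)
          B_temp[r * LDB_ + k] = B[r + k * LDB];
    }
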
- if(K - k >= 4) { - urolls::template transB_kernel<4, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); - B_temp += 4; k += 4; + if (K - k >= 4) { + urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += 4; + k += 4; } } - if(K - k >= 2) { - urolls::template transB_kernel<2, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); - B_temp += 2; k += 2; + if (K - k >= 2) { + urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += 2; + k += 2; } - if(K - k >= 1) { - urolls::template transB_kernel<1, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); - B_temp += 1; k += 1; + if (K - k >= 1) { + urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += 1; + k += 1; } } +#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC) +/** + * Reduce blocking sizes so that the size of the temporary workspace needed is less than "limit" bytes, + * - kB must be at least psize + * - numM must be at least EIGEN_AVX_MAX_NUM_ROW + */ +template +constexpr std::pair trsmBlocking(const int64_t limit) { + constexpr int64_t psize = packet_traits::size; + int64_t kB = 15 * psize; + int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW; + // If B is rowmajor, no temp workspace needed, so use default blocking sizes. + if (isBRowMajor) return {kB, numM}; + + // Very simple heuristic, prefer keeping kB as large as possible to fully use vector registers. + for (int64_t k = kB; k > psize; k -= psize) { + for (int64_t m = numM; m > EIGEN_AVX_MAX_NUM_ROW; m -= EIGEN_AVX_MAX_NUM_ROW) { + if ((((k + psize - 1) / psize + 4) * psize) * m * sizeof(Scalar) < limit) { + return {k, m}; + } + } + } + return {psize, EIGEN_AVX_MAX_NUM_ROW}; // Minimum blocking size required +} +#endif // (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC) + /** * Main triangular solve driver * @@ -869,9 +914,11 @@ void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, * * Note: For RXX cases M,numRHS should be swapped. * -*/ -template + */ +template void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) { + constexpr int64_t psize = packet_traits::size; /** * The values for kB, numM were determined experimentally. * kB: Number of RHS we process at a time. @@ -885,8 +932,30 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L * large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to * determine optimal values for numM. */ - const int64_t kB = (3*packet_traits::size)*5; // 5*U3 - const int64_t numM = 64; +#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC) + /** + * If EIGEN_NO_MALLOC is requested, we try to reduce kB and numM so the maximum temp workspace required is less + * than EIGEN_STACK_ALLOCATION_LIMIT. Actual workspace size may be less, depending on the number of vectors to + * solve. + * - kB must be at least psize + * - numM must be at least EIGEN_AVX_MAX_NUM_ROW + * + * If B is row-major, the blocking sizes are not reduced (no temp workspace needed). + */ + constexpr std::pair blocking_ = trsmBlocking(EIGEN_STACK_ALLOCATION_LIMIT); + constexpr int64_t kB = blocking_.first; + constexpr int64_t numM = blocking_.second; + /** + * If the temp workspace size exceeds EIGEN_STACK_ALLOCATION_LIMIT even with the minimum blocking sizes, + * we throw an assertion. 
Use -DEIGEN_USE_AVX512_TRSM_L_KERNELS=0 if necessary + */ + static_assert(!(((((kB + psize - 1) / psize + 4) * psize) * numM * sizeof(Scalar) >= EIGEN_STACK_ALLOCATION_LIMIT) && + !isBRowMajor), + "Temp workspace required is too large."); +#else + constexpr int64_t kB = (3 * psize) * 5; // 5*U3 + constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW; +#endif int64_t sizeBTemp = 0; Scalar *B_temp = NULL; @@ -896,42 +965,50 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L * transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array. * The updated row-major copy of B is reused in the GEMM updates. */ - sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM; - B_temp = (Scalar*) handmade_aligned_malloc(sizeof(Scalar)*sizeBTemp,4096); + sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM; } - for(int64_t k = 0; k < numRHS; k += kB) { + +#if !defined(EIGEN_NO_MALLOC) + EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64); +#elif (EIGEN_USE_AVX512_TRSM_L_KERNELS) && defined(EIGEN_NO_MALLOC) + // Use alloca if malloc not allowed, requested temp workspace size should be less than EIGEN_STACK_ALLOCATION_LIMIT + ei_declare_aligned_stack_constructed_variable(Scalar, B_temp_alloca, sizeBTemp, 0); + B_temp = B_temp_alloca; +#endif + + for (int64_t k = 0; k < numRHS; k += kB) { int64_t bK = numRHS - k > kB ? kB : numRHS - k; - int64_t M_ = (M/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0; + int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0; // bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS // instead of bK RHS in triSolveKernelLxK. - int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW-1))/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW; - const int64_t numScalarPerCache = 64/sizeof(Scalar); + int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW - 1)) / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW; + const int64_t numScalarPerCache = 64 / sizeof(Scalar); // Leading dimension of B_temp, will be a multiple of the cache line size. - int64_t LDT = ((bkL+(numScalarPerCache-1))/numScalarPerCache)*numScalarPerCache; + int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache; int64_t offsetBTemp = 0; - for(int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + for (int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { EIGEN_IF_CONSTEXPR(!isBRowMajor) { - int64_t indA_i = isFWDSolve ? i : M - 1 - i; - int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW); - int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW*LDT - offsetBTemp; - int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp; + int64_t indA_i = isFWDSolve ? i : M - 1 - i; + int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW * LDT - offsetBTemp; + int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp; // Copy values from B to B_temp. 
- copyBToRowMajor(B_arr + indB_i + k*LDB, LDB, bK, B_temp + offB_1, LDT); + copyBToRowMajor(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT); // Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major) triSolveKernelLxK( - &A_arr[idA(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT); + &A_arr[idA(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT); // Copy values from B_temp back to B. B_temp will be reused in gemm call below. - copyBToRowMajor(B_arr + indB_i + k*LDB, LDB, bK, B_temp + offB_1, LDT); + copyBToRowMajor(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT); - offsetBTemp += EIGEN_AVX_MAX_NUM_ROW*LDT; + offsetBTemp += EIGEN_AVX_MAX_NUM_ROW * LDT; } else { int64_t ind = isFWDSolve ? i : M - 1 - i; triSolveKernelLxK( - &A_arr[idA(ind, ind, LDA)], B_arr + k + ind*LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB); + &A_arr[idA(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB); } - if(i+EIGEN_AVX_MAX_NUM_ROW < M_) { + if (i + EIGEN_AVX_MAX_NUM_ROW < M_) { /** * For the GEMM updates, we want "K" (K=i+8 in this case) to be large as soon as possible * to reuse the accumulators in GEMM as much as possible. So we only update 8xbK blocks of @@ -945,19 +1022,16 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L * |********|__| |**| */ EIGEN_IF_CONSTEXPR(isBRowMajor) { - int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW); - int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW); - int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW); - int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW); - gemmKernel( - &A_arr[idA(indA_i,indA_j,LDA)], - B_arr + k + indB_i*LDB, - B_arr + k + indB_i2*LDB, - EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, - LDA, LDB, LDB); + int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW); + int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW); + gemmKernel( + &A_arr[idA(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, + EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB); } else { - if(offsetBTemp + EIGEN_AVX_MAX_NUM_ROW*LDT > sizeBTemp) { + if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) { /** * Similar idea as mentioned above, but here we are limited by the number of updated values of B * that can be stored (row-major) in B_temp. @@ -966,164 +1040,148 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L * update and partially update the remaining old values of B which depends on the new values * of B stored in B_temp. These values are then no longer needed and can be overwritten. */ - int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0; - int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW); - int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0; - int64_t offB_1 = isFWDSolve ? 
0 : sizeBTemp - offsetBTemp; - gemmKernel( - &A_arr[idA(indA_i, indA_j,LDA)], - B_temp + offB_1, - B_arr + indB_i + (k)*LDB, - M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, - LDA, LDT, LDB); - offsetBTemp = 0; gemmOff = i + EIGEN_AVX_MAX_NUM_ROW; - } - else { + int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0; + int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0; + int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; + gemmKernel( + &A_arr[idA(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB, + M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB); + offsetBTemp = 0; + gemmOff = i + EIGEN_AVX_MAX_NUM_ROW; + } else { /** * If there is enough space in B_temp, we only update the next 8xbK values of B. */ - int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW); - int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW); - int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW); - int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; - gemmKernel( - &A_arr[idA(indA_i,indA_j,LDA)], - B_temp + offB_1, - B_arr + indB_i + (k)*LDB, - EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, - LDA, LDT, LDB); + int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW); + int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW); + int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; + gemmKernel( + &A_arr[idA(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB, + EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB); } } } } // Handle M remainder.. - int64_t bM = M-M_; - if (bM > 0){ - if( M_ > 0) { + int64_t bM = M - M_; + if (bM > 0) { + if (M_ > 0) { EIGEN_IF_CONSTEXPR(isBRowMajor) { - int64_t indA_i = isFWDSolve ? M_ : 0; - int64_t indA_j = isFWDSolve ? 0 : bM; - int64_t indB_i = isFWDSolve ? 0 : bM; + int64_t indA_i = isFWDSolve ? M_ : 0; + int64_t indA_j = isFWDSolve ? 0 : bM; + int64_t indB_i = isFWDSolve ? 0 : bM; int64_t indB_i2 = isFWDSolve ? M_ : 0; - gemmKernel( - &A_arr[idA(indA_i,indA_j,LDA)], - B_arr + k +indB_i*LDB, - B_arr + k + indB_i2*LDB, - bM , bK, M_, - LDA, LDB, LDB); + gemmKernel( + &A_arr[idA(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM, + bK, M_, LDA, LDB, LDB); } else { - int64_t indA_i = isFWDSolve ? M_ : 0; + int64_t indA_i = isFWDSolve ? M_ : 0; int64_t indA_j = isFWDSolve ? gemmOff : bM; - int64_t indB_i = isFWDSolve ? M_ : 0; - int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; - gemmKernel( - &A_arr[idA(indA_i,indA_j,LDA)], - B_temp + offB_1, - B_arr + indB_i + (k)*LDB, - bM , bK, M_ - gemmOff, - LDA, LDT, LDB); + int64_t indB_i = isFWDSolve ? M_ : 0; + int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; + gemmKernel(&A_arr[idA(indA_i, indA_j, LDA)], + B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK, + M_ - gemmOff, LDA, LDT, LDB); } } EIGEN_IF_CONSTEXPR(!isBRowMajor) { - int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_; - int64_t indB_i = isFWDSolve ? M_ : 0; - int64_t offB_1 = isFWDSolve ? 
0 : (bM-1)*bkL; - copyBToRowMajor(B_arr + indB_i + k*LDB, LDB, bK, B_temp, bkL, bM); - triSolveKernelLxK( - &A_arr[idA(indA_i, indA_i, LDA)], B_temp + offB_1, bM, bkL, LDA, bkL); - copyBToRowMajor(B_arr + indB_i + k*LDB, LDB, bK, B_temp, bkL, bM); + int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_; + int64_t indB_i = isFWDSolve ? M_ : 0; + int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL; + copyBToRowMajor(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM); + triSolveKernelLxK(&A_arr[idA(indA_i, indA_i, LDA)], + B_temp + offB_1, bM, bkL, LDA, bkL); + copyBToRowMajor(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM); } else { int64_t ind = isFWDSolve ? M_ : M - 1 - M_; - triSolveKernelLxK( - &A_arr[idA(ind, ind, LDA)], B_arr + k + ind*LDB, bM, bK, LDA, LDB); + triSolveKernelLxK(&A_arr[idA(ind, ind, LDA)], + B_arr + k + ind * LDB, bM, bK, LDA, LDB); } } } +#if !defined(EIGEN_NO_MALLOC) EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp); +#endif } // Template specializations of trsmKernelL/R for float/double and inner strides of 1. -#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) -template +#if (EIGEN_USE_AVX512_TRSM_KERNELS) +#if (EIGEN_USE_AVX512_TRSM_R_KERNELS) +template struct trsmKernelR; template -struct trsmKernelR{ - static void kernel(Index size, Index otherSize, const float* _tri, Index triStride, - float* _other, Index otherIncr, Index otherStride); +struct trsmKernelR { + static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr, + Index otherStride); }; template -struct trsmKernelR{ - static void kernel(Index size, Index otherSize, const double* _tri, Index triStride, - double* _other, Index otherIncr, Index otherStride); +struct trsmKernelR { + static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr, + Index otherStride); }; template EIGEN_DONT_INLINE void trsmKernelR::kernel( - Index size, Index otherSize, - const float* _tri, Index triStride, - float* _other, Index otherIncr, Index otherStride) -{ + Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr, + Index otherStride) { EIGEN_UNUSED_VARIABLE(otherIncr); - triSolve( - const_cast(_tri), _other, size, otherSize, triStride, otherStride); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); } template EIGEN_DONT_INLINE void trsmKernelR::kernel( - Index size, Index otherSize, - const double* _tri, Index triStride, - double* _other, Index otherIncr, Index otherStride) -{ + Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr, + Index otherStride) { EIGEN_UNUSED_VARIABLE(otherIncr); - triSolve( - const_cast(_tri), _other, size, otherSize, triStride, otherStride); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); } +#endif // (EIGEN_USE_AVX512_TRSM_R_KERNELS) -// These trsm kernels require temporary memory allocation, so disable them if malloc is not allowed. 
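/*
 * Illustrative sketch only, not part of the patch: the new #if !defined(EIGEN_NO_MALLOC) guard
 * above compiles the handmade_aligned_free(B_temp) call out when EIGEN_NO_MALLOC is defined,
 * which implies that in that configuration B_temp is not obtained from the heap at all. A
 * minimal standalone version of this allocation-selection pattern, assuming a fixed-size stack
 * buffer is acceptable in the no-malloc path (the function, buffer size and names below are
 * hypothetical):
 */
#include <cstdlib>

void solveWithTemporary(double *B, long rows, long rhs) {
  constexpr long kMaxTempElems = 4096;  // hypothetical upper bound on the temporary
#ifdef EIGEN_NO_MALLOC
  // No heap allocation allowed: fall back to automatic (stack) storage.
  alignas(64) double tempStorage[kMaxTempElems];
  double *B_temp = tempStorage;
#else
  // Heap path; the real kernels use Eigen's aligned allocation helpers instead of raw malloc.
  double *B_temp = static_cast<double *>(std::malloc(kMaxTempElems * sizeof(double)));
#endif
  // ... pack slices of B into B_temp, run the solve, copy the results back ...
  (void)B; (void)rows; (void)rhs; (void)B_temp;
#ifndef EIGEN_NO_MALLOC
  std::free(B_temp);  // only the heap path has a matching free
#endif
}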
-#if defined(EIGEN_USE_AVX512_TRSM_L_KERNELS) -template +// These trsm kernels require temporary memory allocation +#if (EIGEN_USE_AVX512_TRSM_L_KERNELS) +template struct trsmKernelL; template -struct trsmKernelL{ - static void kernel(Index size, Index otherSize, const float* _tri, Index triStride, - float* _other, Index otherIncr, Index otherStride); +struct trsmKernelL { + static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr, + Index otherStride); }; template -struct trsmKernelL{ - static void kernel(Index size, Index otherSize, const double* _tri, Index triStride, - double* _other, Index otherIncr, Index otherStride); +struct trsmKernelL { + static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr, + Index otherStride); }; template EIGEN_DONT_INLINE void trsmKernelL::kernel( - Index size, Index otherSize, - const float* _tri, Index triStride, - float* _other, Index otherIncr, Index otherStride) -{ + Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr, + Index otherStride) { EIGEN_UNUSED_VARIABLE(otherIncr); - triSolve( - const_cast(_tri), _other, size, otherSize, triStride, otherStride); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); } template EIGEN_DONT_INLINE void trsmKernelL::kernel( - Index size, Index otherSize, - const double* _tri, Index triStride, - double* _other, Index otherIncr, Index otherStride) -{ + Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr, + Index otherStride) { EIGEN_UNUSED_VARIABLE(otherIncr); - triSolve( - const_cast(_tri), _other, size, otherSize, triStride, otherStride); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); } -#endif //EIGEN_USE_AVX512_TRSM_L_KERNELS -#endif //EIGEN_USE_AVX512_TRSM_KERNELS -} -} -#endif //EIGEN_TRSM_KERNEL_IMPL_H +#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS +#endif // EIGEN_USE_AVX512_TRSM_KERNELS +} // namespace internal +} // namespace Eigen +#endif // EIGEN_TRSM_KERNEL_IMPL_H diff --git a/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc index 03640c9b7..032937cfd 100644 --- a/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +++ b/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc @@ -11,8 +11,7 @@ #define EIGEN_UNROLLS_IMPL_H template -static EIGEN_ALWAYS_INLINE -int64_t idA(int64_t i, int64_t j, int64_t LDA) { +static EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) { EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j; else return i + j * LDA; } @@ -59,34 +58,38 @@ namespace unrolls { template EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { - EIGEN_IF_CONSTEXPR( N == 16) { return 0xFFFF >> (16 - m); } - else EIGEN_IF_CONSTEXPR( N == 8) { return 0xFF >> (8 - m); } - else EIGEN_IF_CONSTEXPR( N == 4) { return 0x0F >> (4 - m); } + EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); } + else EIGEN_IF_CONSTEXPR(N == 8) { + return 0xFF >> (8 - m); + } + else EIGEN_IF_CONSTEXPR(N == 4) { + return 0x0F >> (4 - m); + } return 0; } template -EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock& kernel); +EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel); template <> -EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock& kernel) { - __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]); - __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]); - __m512 T2 = 
_mm512_unpacklo_ps(kernel.packet[2],kernel.packet[3]); - __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2],kernel.packet[3]); - __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4],kernel.packet[5]); - __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4],kernel.packet[5]); - __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]); - __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]); +EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { + __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); + __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); + __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]); + __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]); + __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]); + __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]); + __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]); + __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]); - kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2))); - kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0),_mm512_castps_pd(T2))); - kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3))); - kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1),_mm512_castps_pd(T3))); - kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6))); - kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6))); - kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7))); - kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7))); + kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); + kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); + kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); + kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); + kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); + kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); + kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); + kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); @@ -105,14 +108,18 @@ EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock& kernel) { T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); - kernel.packet[0] = T0; kernel.packet[1] = T1; - kernel.packet[2] = T2; kernel.packet[3] = T3; - kernel.packet[4] = T4; kernel.packet[5] = T5; - kernel.packet[6] = T6; kernel.packet[7] = T7; + kernel.packet[0] = T0; + kernel.packet[1] = T1; + kernel.packet[2] = T2; + kernel.packet[3] = T3; + kernel.packet[4] = T4; + kernel.packet[5] = T5; + kernel.packet[6] = T6; + kernel.packet[7] = T7; } template <> -EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock& kernel) { +EIGEN_ALWAYS_INLINE void 
trans8x8blocks(PacketBlock &kernel) { ptranspose(kernel); } @@ -121,12 +128,11 @@ EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock& kernel) { */ template class trans { -public: + public: using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; static constexpr int64_t PacketSize = packet_traits::size; - /*********************************** * Auxillary Functions for: * - storeC @@ -142,70 +148,67 @@ public: * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC * **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> - aux_storeC(Scalar *C_arr, int64_t LDC, - PacketBlock &zmm, int64_t remM_ = 0) { - constexpr int64_t counterReverse = endN-counter; + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC( + Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { + constexpr int64_t counterReverse = endN - counter; constexpr int64_t startN = counterReverse; EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) { EIGEN_IF_CONSTEXPR(remM) { pstoreu( - C_arr + LDC*startN, - padd(ploadu((const Scalar*)C_arr + LDC*startN, remMask(remM_)), - preinterpret(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]), - remMask(remM_)), - remMask(remM_)); + C_arr + LDC * startN, + padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), + preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]), + remMask(remM_)), + remMask(remM_)); } else { - pstoreu( - C_arr + LDC*startN, - padd(ploadu((const Scalar*)C_arr + LDC*startN), - preinterpret(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]))); + pstoreu(C_arr + LDC * startN, + padd(ploadu((const Scalar *)C_arr + LDC * startN), + preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]))); } } - else { // This block is only needed for fp32 case + else { // This block is only needed for fp32 case // Reinterpret as __m512 for _mm512_shuffle_f32x4 vecFullFloat zmm2vecFullFloat = preinterpret( - zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)]); + zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]); // Swap lower and upper half of avx register. 
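      // (Shuffle immediate note: 0b01001110 == 0x4E selects the 128-bit lanes of the source in the
      // order {2, 3, 0, 1}, so with both operands equal the two 256-bit halves of the zmm register
      // trade places. The half that still has to be written to C then sits in the low 256 bits,
      // where the ymm-sized loads/stores below can reach it.)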
- zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] = - preinterpret(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110)); + zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] = + preinterpret(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110)); EIGEN_IF_CONSTEXPR(remM) { pstoreu( - C_arr + LDC*startN, - padd(ploadu((const Scalar*)C_arr + LDC*startN, - remMask(remM_)), - preinterpret(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])), - remMask(remM_)); + C_arr + LDC * startN, + padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), + preinterpret( + zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])), + remMask(remM_)); } else { pstoreu( - C_arr + LDC*startN, - padd(ploadu((const Scalar*)C_arr + LDC*startN), - preinterpret(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)]))); + C_arr + LDC * startN, + padd(ploadu((const Scalar *)C_arr + LDC * startN), + preinterpret( + zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]))); } } aux_storeC(C_arr, LDC, zmm, remM_); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && endN <= PacketSize)> - aux_storeC(Scalar *C_arr, int64_t LDC, - PacketBlock &zmm, int64_t remM_ = 0) - { - EIGEN_UNUSED_VARIABLE(C_arr); - EIGEN_UNUSED_VARIABLE(LDC); - EIGEN_UNUSED_VARIABLE(zmm); - EIGEN_UNUSED_VARIABLE(remM_); - } + template + static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && endN <= PacketSize)> aux_storeC( + Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { + EIGEN_UNUSED_VARIABLE(C_arr); + EIGEN_UNUSED_VARIABLE(LDC); + EIGEN_UNUSED_VARIABLE(zmm); + EIGEN_UNUSED_VARIABLE(remM_); + } - template - static EIGEN_ALWAYS_INLINE - void storeC(Scalar *C_arr, int64_t LDC, - PacketBlock &zmm, int64_t remM_ = 0){ + template + static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, + PacketBlock &zmm, + int64_t remM_ = 0) { aux_storeC(C_arr, LDC, zmm, remM_); } @@ -234,30 +237,29 @@ public: * EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of * avx registers are being transposed. */ - template - static EIGEN_ALWAYS_INLINE - void transpose(PacketBlock &zmm) { + template + static EIGEN_ALWAYS_INLINE void transpose(PacketBlock &zmm) { // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. 
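    // The packets that form one 8-row tile are not adjacent in zmm: row j of the tile lives at
    // index packetIndexOffset + j * zmmStride, with zmmStride = unrollN / PacketSize. They are
    // gathered into the local PacketBlock r, transposed in registers by trans8x8blocks(), and
    // written back to the same strided slots.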
- constexpr int64_t zmmStride = unrollN/PacketSize; - PacketBlock r; - r.packet[0] = zmm.packet[packetIndexOffset + zmmStride*0]; - r.packet[1] = zmm.packet[packetIndexOffset + zmmStride*1]; - r.packet[2] = zmm.packet[packetIndexOffset + zmmStride*2]; - r.packet[3] = zmm.packet[packetIndexOffset + zmmStride*3]; - r.packet[4] = zmm.packet[packetIndexOffset + zmmStride*4]; - r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5]; - r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6]; - r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7]; + constexpr int64_t zmmStride = unrollN / PacketSize; + PacketBlock r; + r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0]; + r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1]; + r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2]; + r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3]; + r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4]; + r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5]; + r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6]; + r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7]; trans8x8blocks(r); - zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0]; - zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1]; - zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2]; - zmm.packet[packetIndexOffset + zmmStride*3] = r.packet[3]; - zmm.packet[packetIndexOffset + zmmStride*4] = r.packet[4]; - zmm.packet[packetIndexOffset + zmmStride*5] = r.packet[5]; - zmm.packet[packetIndexOffset + zmmStride*6] = r.packet[6]; - zmm.packet[packetIndexOffset + zmmStride*7] = r.packet[7]; + zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0]; + zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1]; + zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2]; + zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3]; + zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4]; + zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5]; + zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6]; + zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7]; } }; @@ -277,7 +279,7 @@ public: */ template class transB { -public: + public: using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; static constexpr int64_t PacketSize = packet_traits::size; @@ -297,33 +299,31 @@ public: * 1-D unroll * for(startN = 0; startN < endN; startN++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_loadB(Scalar *B_arr, int64_t LDB, - PacketBlock &ymm, int64_t remM_ = 0) { - constexpr int64_t counterReverse = endN-counter; + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( + Scalar *B_arr, int64_t LDB, PacketBlock &ymm, + int64_t remM_ = 0) { + constexpr int64_t counterReverse = endN - counter; constexpr int64_t startN = counterReverse; EIGEN_IF_CONSTEXPR(remM) { - ymm.packet[packetIndexOffset + startN] = ploadu( - (const Scalar*)&B_arr[startN*LDB], remMask(remM_)); + ymm.packet[packetIndexOffset + startN] = + ploadu((const Scalar *)&B_arr[startN * LDB], remMask(remM_)); } - else - ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar*)&B_arr[startN*LDB]); + else ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB]); - aux_loadB(B_arr, LDB, ymm, remM_); + aux_loadB(B_arr, LDB, ymm, remM_); } - template - static EIGEN_ALWAYS_INLINE 
std::enable_if_t<(counter <= 0)> - aux_loadB(Scalar *B_arr, int64_t LDB, - PacketBlock &ymm, int64_t remM_ = 0) - { - EIGEN_UNUSED_VARIABLE(B_arr); - EIGEN_UNUSED_VARIABLE(LDB); - EIGEN_UNUSED_VARIABLE(ymm); - EIGEN_UNUSED_VARIABLE(remM_); - } + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( + Scalar *B_arr, int64_t LDB, PacketBlock &ymm, + int64_t remM_ = 0) { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(ymm); + EIGEN_UNUSED_VARIABLE(remM_); + } /** * aux_storeB @@ -331,36 +331,31 @@ public: * 1-D unroll * for(startN = 0; startN < endN; startN++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_storeB(Scalar *B_arr, int64_t LDB, - PacketBlock &ymm, int64_t rem_ = 0) { - constexpr int64_t counterReverse = endN-counter; + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB( + Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { + constexpr int64_t counterReverse = endN - counter; constexpr int64_t startN = counterReverse; - EIGEN_IF_CONSTEXPR( remK || remM) { - pstoreu( - &B_arr[startN*LDB], - ymm.packet[packetIndexOffset + startN], - remMask(rem_)); + EIGEN_IF_CONSTEXPR(remK || remM) { + pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN], + remMask(rem_)); } else { - pstoreu(&B_arr[startN*LDB], ymm.packet[packetIndexOffset + startN]); + pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]); } - aux_storeB(B_arr, LDB, ymm, rem_); + aux_storeB(B_arr, LDB, ymm, rem_); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_storeB(Scalar *B_arr, int64_t LDB, - PacketBlock &ymm, int64_t rem_ = 0) - { - EIGEN_UNUSED_VARIABLE(B_arr); - EIGEN_UNUSED_VARIABLE(LDB); - EIGEN_UNUSED_VARIABLE(ymm); - EIGEN_UNUSED_VARIABLE(rem_); - } + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB( + Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(ymm); + EIGEN_UNUSED_VARIABLE(rem_); + } /** * aux_loadBBlock @@ -368,32 +363,27 @@ public: * 1-D unroll * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, - PacketBlock &ymm, - int64_t remM_ = 0) { - constexpr int64_t counterReverse = endN-counter; + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock( + Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, int64_t remM_ = 0) { + constexpr int64_t counterReverse = endN - counter; constexpr int64_t startN = counterReverse; - transB::template loadB(&B_temp[startN], LDB_, ymm); - aux_loadBBlock( - B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template loadB(&B_temp[startN], LDB_, ymm); + aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, - PacketBlock &ymm, - int64_t remM_ = 0) - { - EIGEN_UNUSED_VARIABLE(B_arr); - EIGEN_UNUSED_VARIABLE(LDB); - EIGEN_UNUSED_VARIABLE(B_temp); - EIGEN_UNUSED_VARIABLE(LDB_); - EIGEN_UNUSED_VARIABLE(ymm); - EIGEN_UNUSED_VARIABLE(remM_); - } - + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock( + Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t 
LDB_, + PacketBlock &ymm, int64_t remM_ = 0) { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(B_temp); + EIGEN_UNUSED_VARIABLE(LDB_); + EIGEN_UNUSED_VARIABLE(ymm); + EIGEN_UNUSED_VARIABLE(remM_); + } /** * aux_storeBBlock @@ -401,88 +391,75 @@ public: * 1-D unroll * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, - PacketBlock &ymm, - int64_t remM_ = 0) { - constexpr int64_t counterReverse = endN-counter; + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock( + Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, int64_t remM_ = 0) { + constexpr int64_t counterReverse = endN - counter; constexpr int64_t startN = counterReverse; EIGEN_IF_CONSTEXPR(toTemp) { - transB::template storeB( - &B_temp[startN], LDB_, ymm, remK_); + transB::template storeB(&B_temp[startN], LDB_, ymm, remK_); } else { - transB::template storeB( - &B_arr[0 + startN*LDB], LDB, ymm, remM_); + transB::template storeB(&B_arr[0 + startN * LDB], LDB, + ymm, remM_); } - aux_storeBBlock( - B_arr, LDB, B_temp, LDB_, ymm, remM_); + aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, - PacketBlock &ymm, - int64_t remM_ = 0) - { - EIGEN_UNUSED_VARIABLE(B_arr); - EIGEN_UNUSED_VARIABLE(LDB); - EIGEN_UNUSED_VARIABLE(B_temp); - EIGEN_UNUSED_VARIABLE(LDB_); - EIGEN_UNUSED_VARIABLE(ymm); - EIGEN_UNUSED_VARIABLE(remM_); - } - + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock( + Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, int64_t remM_ = 0) { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(B_temp); + EIGEN_UNUSED_VARIABLE(LDB_); + EIGEN_UNUSED_VARIABLE(ymm); + EIGEN_UNUSED_VARIABLE(remM_); + } /******************************************************** * Wrappers for aux_XXXX to hide counter parameter ********************************************************/ - template - static EIGEN_ALWAYS_INLINE - void loadB(Scalar *B_arr, int64_t LDB, - PacketBlock &ymm, int64_t remM_ = 0) { + template + static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB, + PacketBlock &ymm, + int64_t remM_ = 0) { aux_loadB(B_arr, LDB, ymm, remM_); } - template - static EIGEN_ALWAYS_INLINE - void storeB(Scalar *B_arr, int64_t LDB, - PacketBlock &ymm, int64_t rem_ = 0) { + template + static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB, + PacketBlock &ymm, + int64_t rem_ = 0) { aux_storeB(B_arr, LDB, ymm, rem_); } - template - static EIGEN_ALWAYS_INLINE - void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, - PacketBlock &ymm, - int64_t remM_ = 0) { - EIGEN_IF_CONSTEXPR(toTemp) { - transB::template loadB(&B_arr[0],LDB, ymm, remM_); - } + template + static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, + int64_t remM_ = 0) { + EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB(&B_arr[0], LDB, ymm, remM_); } else { - aux_loadBBlock( - B_arr, LDB, B_temp, LDB_, ymm, remM_); + aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); } } - template - static EIGEN_ALWAYS_INLINE - void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, 
int64_t LDB_, - PacketBlock &ymm, - int64_t remM_ = 0) { - aux_storeBBlock( - B_arr, LDB, B_temp, LDB_, ymm, remM_); + template + static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, + int64_t remM_ = 0) { + aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); } - template - static EIGEN_ALWAYS_INLINE - void transposeLxL(PacketBlock &ymm){ + template + static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock &ymm) { // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. - PacketBlock r; + PacketBlock r; r.packet[0] = ymm.packet[packetIndexOffset + 0]; r.packet[1] = ymm.packet[packetIndexOffset + 1]; r.packet[2] = ymm.packet[packetIndexOffset + 2]; @@ -502,10 +479,10 @@ public: ymm.packet[packetIndexOffset + 7] = r.packet[7]; } - template - static EIGEN_ALWAYS_INLINE - void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, - PacketBlock &ymm, int64_t remM_ = 0) { + template + static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, + int64_t remM_ = 0) { constexpr int64_t U3 = PacketSize * 3; constexpr int64_t U2 = PacketSize * 2; constexpr int64_t U1 = PacketSize * 1; @@ -518,70 +495,70 @@ public: */ EIGEN_IF_CONSTEXPR(unrollN == U3) { // load LxU3 B col major, transpose LxU3 row major - constexpr int64_t maxUBlock = std::min(3*EIGEN_AVX_MAX_NUM_ROW, U3); - transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); - transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm); - transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm); - transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm); - transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3); + transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); - EIGEN_IF_CONSTEXPR( maxUBlock < U3) { - transB::template loadBBlock(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); - transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm); - transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm); - transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm); - transB::template storeBBlock(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); + EIGEN_IF_CONSTEXPR(maxUBlock < U3) { + transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, + ymm, remM_); + transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, + ymm, remM_); } } else EIGEN_IF_CONSTEXPR(unrollN == U2) { // load LxU2 B col major, transpose LxU2 row major - constexpr int64_t maxUBlock = std::min(3*EIGEN_AVX_MAX_NUM_ROW, U2); - transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); - transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm); - transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm); - EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template 
transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm); - transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2); + transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); + EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); - EIGEN_IF_CONSTEXPR( maxUBlock < U2) { - transB::template loadBBlock( - &B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); + EIGEN_IF_CONSTEXPR(maxUBlock < U2) { + transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, + &B_temp[maxUBlock], LDB_, ymm, remM_); transB::template transposeLxL<0>(ymm); - transB::template storeBBlock( - &B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); + transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, + &B_temp[maxUBlock], LDB_, ymm, remM_); } } else EIGEN_IF_CONSTEXPR(unrollN == U1) { // load LxU1 B col major, transpose LxU1 row major - transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); transB::template transposeLxL<0>(ymm); - EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { - transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm); - } - transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); } + transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); } else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) { // load Lx4 B col major, transpose Lx4 row major - transB::template loadBBlock<8,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); transB::template transposeLxL<0>(ymm); - transB::template storeBBlock<8,toTemp,remM,8>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_); } else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) { // load Lx4 B col major, transpose Lx4 row major - transB::template loadBBlock<4,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); transB::template transposeLxL<0>(ymm); - transB::template storeBBlock<4,toTemp,remM,4>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_); } else EIGEN_IF_CONSTEXPR(unrollN == 2) { // load Lx2 B col major, transpose Lx2 row major - transB::template loadBBlock<2,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); transB::template transposeLxL<0>(ymm); - transB::template storeBBlock<2,toTemp,remM,2>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_); } else EIGEN_IF_CONSTEXPR(unrollN == 1) { // load Lx1 B col major, transpose Lx1 row major - transB::template loadBBlock<1,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); transB::template transposeLxL<0>(ymm); - transB::template storeBBlock<1,toTemp,remM,1>(B_arr, LDB, B_temp, LDB_, ymm, 
remM_); + transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_); } } }; @@ -600,10 +577,8 @@ public: */ template class trsm { -public: - using vec = typename std::conditional::value, - vecFullFloat, - vecFullDouble>::type; + public: + using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; static constexpr int64_t PacketSize = packet_traits::size; /*********************************** @@ -621,35 +596,33 @@ public: * for(startM = 0; startM < endM; startM++) * for(startK = 0; startK < endK; startK++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS( + Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { + constexpr int64_t counterReverse = endM * endK - counter; + constexpr int64_t startM = counterReverse / (endK); + constexpr int64_t startK = counterReverse % endK; - constexpr int64_t counterReverse = endM*endK-counter; - constexpr int64_t startM = counterReverse/(endK); - constexpr int64_t startK = counterReverse%endK; - - constexpr int64_t packetIndex = startM*endK + startK; + constexpr int64_t packetIndex = startM * endK + startK; constexpr int64_t startM_ = isFWDSolve ? startM : -startM; - const int64_t rhsIndex = (startK*PacketSize) + startM_*LDB; + const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; EIGEN_IF_CONSTEXPR(krem) { RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex], remMask(rem)); } else { RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex]); } - aux_loadRHS(B_arr, LDB, RHSInPacket, rem); + aux_loadRHS(B_arr, LDB, RHSInPacket, rem); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) - { - EIGEN_UNUSED_VARIABLE(B_arr); - EIGEN_UNUSED_VARIABLE(LDB); - EIGEN_UNUSED_VARIABLE(RHSInPacket); - EIGEN_UNUSED_VARIABLE(rem); - } + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS( + Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(RHSInPacket); + EIGEN_UNUSED_VARIABLE(rem); + } /** * aux_storeRHS @@ -658,34 +631,33 @@ public: * for(startM = 0; startM < endM; startM++) * for(startK = 0; startK < endK; startK++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { - constexpr int64_t counterReverse = endM*endK-counter; - constexpr int64_t startM = counterReverse/(endK); - constexpr int64_t startK = counterReverse%endK; + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS( + Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { + constexpr int64_t counterReverse = endM * endK - counter; + constexpr int64_t startM = counterReverse / (endK); + constexpr int64_t startK = counterReverse % endK; - constexpr int64_t packetIndex = startM*endK + startK; + constexpr int64_t packetIndex = startM * endK + startK; constexpr int64_t startM_ = isFWDSolve ? 
startM : -startM; - const int64_t rhsIndex = (startK*PacketSize) + startM_*LDB; + const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; EIGEN_IF_CONSTEXPR(krem) { pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask(rem)); } else { pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]); } - aux_storeRHS(B_arr, LDB, RHSInPacket, rem); + aux_storeRHS(B_arr, LDB, RHSInPacket, rem); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) - { - EIGEN_UNUSED_VARIABLE(B_arr); - EIGEN_UNUSED_VARIABLE(LDB); - EIGEN_UNUSED_VARIABLE(RHSInPacket); - EIGEN_UNUSED_VARIABLE(rem); - } + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS( + Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(RHSInPacket); + EIGEN_UNUSED_VARIABLE(rem); + } /** * aux_divRHSByDiag @@ -695,20 +667,20 @@ public: * 1-D unroll * for(startK = 0; startK < endK; startK++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> - aux_divRHSByDiag(PacketBlock &RHSInPacket, PacketBlock &AInPacket) { - constexpr int64_t counterReverse = endK-counter; + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag( + PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + constexpr int64_t counterReverse = endK - counter; constexpr int64_t startK = counterReverse; - constexpr int64_t packetIndex = currM*endK + startK; + constexpr int64_t packetIndex = currM * endK + startK; RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]); - aux_divRHSByDiag(RHSInPacket, AInPacket); + aux_divRHSByDiag(RHSInPacket, AInPacket); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && currM >= 0)> - aux_divRHSByDiag(PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && currM >= 0)> aux_divRHSByDiag( + PacketBlock &RHSInPacket, PacketBlock &AInPacket) { EIGEN_UNUSED_VARIABLE(RHSInPacket); EIGEN_UNUSED_VARIABLE(AInPacket); } @@ -720,52 +692,53 @@ public: * for(startM = initM; startM < endM; startM++) * for(startK = 0; startK < endK; startK++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS( + Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, + PacketBlock &AInPacket) { + constexpr int64_t counterReverse = (endM - initM) * endK - counter; + constexpr int64_t startM = initM + counterReverse / (endK); + constexpr int64_t startK = counterReverse % endK; - constexpr int64_t counterReverse = (endM-initM)*endK-counter; - constexpr int64_t startM = initM + counterReverse/(endK); - constexpr int64_t startK = counterReverse%endK; - - // For each row of A, first update all corresponding RHS - constexpr int64_t packetIndex = startM*endK + startK; - EIGEN_IF_CONSTEXPR(currentM > 0) { - RHSInPacket.packet[packetIndex] = - pnmadd(AInPacket.packet[startM], - RHSInPacket.packet[(currentM-1)*endK+startK], + // For each row of A, first update all corresponding RHS + constexpr int64_t packetIndex = startM * endK + startK; + EIGEN_IF_CONSTEXPR(currentM > 0) { + RHSInPacket.packet[packetIndex] = + 
pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK], RHSInPacket.packet[packetIndex]); - } + } - EIGEN_IF_CONSTEXPR(startK == endK - 1) { - // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}. - EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) { - // If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM]. - // This will be used in divRHSByDiag - EIGEN_IF_CONSTEXPR(isFWDSolve) - AInPacket.packet[currentM] = pset1(Scalar(1)/A_arr[idA(currentM,currentM,LDA)]); - else - AInPacket.packet[currentM] = pset1(Scalar(1)/A_arr[idA(-currentM,-currentM,LDA)]); - } - else { - // Broadcast next off diagonal element of A - EIGEN_IF_CONSTEXPR(isFWDSolve) - AInPacket.packet[startM] = pset1(A_arr[idA(startM,currentM,LDA)]); - else - AInPacket.packet[startM] = pset1(A_arr[idA(-startM,-currentM,LDA)]); - } + EIGEN_IF_CONSTEXPR(startK == endK - 1) { + // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}. + EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) { + // If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM]. + // This will be used in divRHSByDiag + EIGEN_IF_CONSTEXPR(isFWDSolve) + AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(currentM, currentM, LDA)]); + else AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(-currentM, -currentM, LDA)]); } + else { + // Broadcast next off diagonal element of A + EIGEN_IF_CONSTEXPR(isFWDSolve) + AInPacket.packet[startM] = pset1(A_arr[idA(startM, currentM, LDA)]); + else AInPacket.packet[startM] = pset1(A_arr[idA(-startM, -currentM, LDA)]); + } + } - aux_updateRHS(A_arr, LDA, RHSInPacket, AInPacket); + aux_updateRHS( + A_arr, LDA, RHSInPacket, AInPacket); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { - EIGEN_UNUSED_VARIABLE(A_arr); - EIGEN_UNUSED_VARIABLE(LDA); - EIGEN_UNUSED_VARIABLE(RHSInPacket); - EIGEN_UNUSED_VARIABLE(AInPacket); + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS( + Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, + PacketBlock &AInPacket) { + EIGEN_UNUSED_VARIABLE(A_arr); + EIGEN_UNUSED_VARIABLE(LDA); + EIGEN_UNUSED_VARIABLE(RHSInPacket); + EIGEN_UNUSED_VARIABLE(AInPacket); } /** @@ -775,10 +748,10 @@ public: * for(startM = 0; startM < endM; startM++) **/ template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { - - constexpr int64_t counterReverse = endM-counter; + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel( + Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, + PacketBlock &AInPacket) { + constexpr int64_t counterReverse = endM - counter; constexpr int64_t startM = counterReverse; constexpr int64_t currentM = startM; @@ -789,31 +762,31 @@ public: // this is handled with enable_if to prevent out-of-bound warnings // from the compiler EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0) - trsm::template divRHSByDiag(RHSInPacket, AInPacket); + trsm::template divRHSByDiag(RHSInPacket, AInPacket); // After division, the rhs corresponding to subsequent rows of A can be partially updated // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed) // to be used 
in the next iteration. - trsm::template - updateRHS( - A_arr, LDA, RHSInPacket, AInPacket); + trsm::template updateRHS(A_arr, LDA, RHSInPacket, + AInPacket); // Handle division for the RHS corresponding to the final row of A. - EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM-1) - trsm::template divRHSByDiag(RHSInPacket, AInPacket); + EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1) + trsm::template divRHSByDiag(RHSInPacket, AInPacket); - aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, AInPacket); + aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, + AInPacket); } template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) - { - EIGEN_UNUSED_VARIABLE(A_arr); - EIGEN_UNUSED_VARIABLE(LDA); - EIGEN_UNUSED_VARIABLE(RHSInPacket); - EIGEN_UNUSED_VARIABLE(AInPacket); - } + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel( + Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, + PacketBlock &AInPacket) { + EIGEN_UNUSED_VARIABLE(A_arr); + EIGEN_UNUSED_VARIABLE(LDA); + EIGEN_UNUSED_VARIABLE(RHSInPacket); + EIGEN_UNUSED_VARIABLE(AInPacket); + } /******************************************************** * Wrappers for aux_XXXX to hide counter parameter @@ -823,40 +796,42 @@ public: * Load endMxendK block of B to RHSInPacket * Masked loads are used for cases where endK is not a multiple of PacketSize */ - template - static EIGEN_ALWAYS_INLINE - void loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { - aux_loadRHS(B_arr, LDB, RHSInPacket, rem); + template + static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB, + PacketBlock &RHSInPacket, int64_t rem = 0) { + aux_loadRHS(B_arr, LDB, RHSInPacket, rem); } /** * Load endMxendK block of B to RHSInPacket * Masked loads are used for cases where endK is not a multiple of PacketSize */ - template - static EIGEN_ALWAYS_INLINE - void storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { - aux_storeRHS(B_arr, LDB, RHSInPacket, rem); + template + static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB, + PacketBlock &RHSInPacket, int64_t rem = 0) { + aux_storeRHS(B_arr, LDB, RHSInPacket, rem); } /** * Only used if Triangular matrix has non-unit diagonal values */ - template - static EIGEN_ALWAYS_INLINE - void divRHSByDiag(PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + template + static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock &RHSInPacket, + PacketBlock &AInPacket) { aux_divRHSByDiag(RHSInPacket, AInPacket); } /** * Update right-hand sides (stored in avx registers) * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket. - **/ - template - static EIGEN_ALWAYS_INLINE - void updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { - aux_updateRHS( - A_arr, LDA, RHSInPacket, AInPacket); + **/ + template + static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA, + PacketBlock &RHSInPacket, + PacketBlock &AInPacket) { + aux_updateRHS( + A_arr, LDA, RHSInPacket, AInPacket); } /** @@ -866,11 +841,11 @@ public: * isUnitDiag: true => triangular matrix has unit diagonal. 
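   *
   * In terms of the math, one call performs a complete substitution sweep on an endM-row block of
   * right-hand sides held in RHSInPacket (numK packets, i.e. numK*PacketSize columns, per row):
   * for each row m of the triangular block,
   *   X_m <- X_m * (1 / A_{m,m})                    (skipped when isUnitDiag)
   *   X_i <- X_i - A_{i,m} * X_m    for the rows i handled after m,
   * with the broadcasts of A staged one step ahead in AInPacket so they overlap with the FMAs.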
*/ template - static EIGEN_ALWAYS_INLINE - void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { - static_assert( numK >= 1 && numK <= 3, "numK out of range" ); - aux_triSolveMicroKernel( - A_arr, LDA, RHSInPacket, AInPacket); + static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, + PacketBlock &RHSInPacket, + PacketBlock &AInPacket) { + static_assert(numK >= 1 && numK <= 3, "numK out of range"); + aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, AInPacket); } }; @@ -881,7 +856,7 @@ public: */ template class gemm { -public: + public: using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; static constexpr int64_t PacketSize = packet_traits::size; @@ -901,23 +876,22 @@ public: * for(startM = 0; startM < endM; startM++) * for(startN = 0; startN < endN; startN++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_setzero(PacketBlock &zmm) { - constexpr int64_t counterReverse = endM*endN-counter; - constexpr int64_t startM = counterReverse/(endN); - constexpr int64_t startN = counterReverse%endN; + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero( + PacketBlock &zmm) { + constexpr int64_t counterReverse = endM * endN - counter; + constexpr int64_t startM = counterReverse / (endN); + constexpr int64_t startN = counterReverse % endN; - zmm.packet[startN*endM + startM] = pzero(zmm.packet[startN*endM + startM]); - aux_setzero(zmm); + zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]); + aux_setzero(zmm); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_setzero(PacketBlock &zmm) - { - EIGEN_UNUSED_VARIABLE(zmm); - } + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero( + PacketBlock &zmm) { + EIGEN_UNUSED_VARIABLE(zmm); + } /** * aux_updateC @@ -926,34 +900,31 @@ public: * for(startM = 0; startM < endM; startM++) * for(startN = 0; startN < endN; startN++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC( + Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { EIGEN_UNUSED_VARIABLE(rem_); - constexpr int64_t counterReverse = endM*endN-counter; - constexpr int64_t startM = counterReverse/(endN); - constexpr int64_t startN = counterReverse%endN; + constexpr int64_t counterReverse = endM * endN - counter; + constexpr int64_t startM = counterReverse / (endN); + constexpr int64_t startN = counterReverse % endN; EIGEN_IF_CONSTEXPR(rem) - zmm.packet[startN*endM + startM] = - padd(ploadu(&C_arr[(startN) * LDC + startM*PacketSize], remMask(rem_)), - zmm.packet[startN*endM + startM], - remMask(rem_)); - else - zmm.packet[startN*endM + startM] = - padd(ploadu(&C_arr[(startN) * LDC + startM*PacketSize]), zmm.packet[startN*endM + startM]); - aux_updateC(C_arr, LDC, zmm, rem_); + zmm.packet[startN * endM + startM] = + padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize], remMask(rem_)), + zmm.packet[startN * endM + startM], remMask(rem_)); + else zmm.packet[startN * endM + startM] = + padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]); + aux_updateC(C_arr, LDC, zmm, rem_); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t 
rem_ = 0) - { - EIGEN_UNUSED_VARIABLE(C_arr); - EIGEN_UNUSED_VARIABLE(LDC); - EIGEN_UNUSED_VARIABLE(zmm); - EIGEN_UNUSED_VARIABLE(rem_); - } + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC( + Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { + EIGEN_UNUSED_VARIABLE(C_arr); + EIGEN_UNUSED_VARIABLE(LDC); + EIGEN_UNUSED_VARIABLE(zmm); + EIGEN_UNUSED_VARIABLE(rem_); + } /** * aux_storeC @@ -962,30 +933,29 @@ public: * for(startM = 0; startM < endM; startM++) * for(startN = 0; startN < endN; startN++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC( + Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { EIGEN_UNUSED_VARIABLE(rem_); - constexpr int64_t counterReverse = endM*endN-counter; - constexpr int64_t startM = counterReverse/(endN); - constexpr int64_t startN = counterReverse%endN; + constexpr int64_t counterReverse = endM * endN - counter; + constexpr int64_t startM = counterReverse / (endN); + constexpr int64_t startN = counterReverse % endN; EIGEN_IF_CONSTEXPR(rem) - pstoreu(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM], remMask(rem_)); - else - pstoreu(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM]); - aux_storeC(C_arr, LDC, zmm, rem_); + pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM], + remMask(rem_)); + else pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]); + aux_storeC(C_arr, LDC, zmm, rem_); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) - { - EIGEN_UNUSED_VARIABLE(C_arr); - EIGEN_UNUSED_VARIABLE(LDC); - EIGEN_UNUSED_VARIABLE(zmm); - EIGEN_UNUSED_VARIABLE(rem_); - } + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC( + Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { + EIGEN_UNUSED_VARIABLE(C_arr); + EIGEN_UNUSED_VARIABLE(LDC); + EIGEN_UNUSED_VARIABLE(zmm); + EIGEN_UNUSED_VARIABLE(rem_); + } /** * aux_startLoadB @@ -993,28 +963,25 @@ public: * 1-D unroll * for(startL = 0; startL < endL; startL++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_startLoadB(Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB( + Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { EIGEN_UNUSED_VARIABLE(rem_); - constexpr int64_t counterReverse = endL-counter; + constexpr int64_t counterReverse = endL - counter; constexpr int64_t startL = counterReverse; EIGEN_IF_CONSTEXPR(rem) - zmm.packet[unrollM*unrollN+startL] = - ploadu(&B_t[(startL/unrollM)*LDB + (startL%unrollM)*PacketSize], remMask(rem_)); - else - zmm.packet[unrollM*unrollN+startL] = ploadu(&B_t[(startL/unrollM)*LDB + (startL%unrollM)*PacketSize]); + zmm.packet[unrollM * unrollN + startL] = + ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask(rem_)); + else zmm.packet[unrollM * unrollN + startL] = + ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]); - aux_startLoadB(B_t, LDB, zmm, rem_); + aux_startLoadB(B_t, LDB, zmm, rem_); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - 
aux_startLoadB( - Scalar *B_t, int64_t LDB, - PacketBlock &zmm, int64_t rem_ = 0) - { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB( + Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { EIGEN_UNUSED_VARIABLE(B_t); EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(zmm); @@ -1027,21 +994,20 @@ public: * 1-D unroll * for(startB = 0; startB < endB; startB++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock &zmm) { - constexpr int64_t counterReverse = endB-counter; + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA( + Scalar *A_t, int64_t LDA, PacketBlock &zmm) { + constexpr int64_t counterReverse = endB - counter; constexpr int64_t startB = counterReverse; - zmm.packet[unrollM*unrollN+numLoad+startB] = pload1(&A_t[idA(startB, 0,LDA)]); + zmm.packet[unrollM * unrollN + numLoad + startB] = pload1(&A_t[idA(startB, 0, LDA)]); - aux_startBCastA(A_t, LDA, zmm); + aux_startBCastA(A_t, LDA, zmm); } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock &zmm) - { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA( + Scalar *A_t, int64_t LDA, PacketBlock &zmm) { EIGEN_UNUSED_VARIABLE(A_t); EIGEN_UNUSED_VARIABLE(LDA); EIGEN_UNUSED_VARIABLE(zmm); @@ -1054,33 +1020,32 @@ public: * 1-D unroll * for(startM = 0; startM < endM; startM++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_loadB(Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( + Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { EIGEN_UNUSED_VARIABLE(rem_); - if ((numLoad/endM + currK < unrollK)) { - constexpr int64_t counterReverse = endM-counter; + if ((numLoad / endM + currK < unrollK)) { + constexpr int64_t counterReverse = endM - counter; constexpr int64_t startM = counterReverse; EIGEN_IF_CONSTEXPR(rem) { - zmm.packet[endM*unrollN+(startM+currK*endM)%numLoad] = - ploadu(&B_t[(numLoad/endM + currK)*LDB + startM*PacketSize], remMask(rem_)); + zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = + ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask(rem_)); } else { - zmm.packet[endM*unrollN+(startM+currK*endM)%numLoad] = - ploadu(&B_t[(numLoad/endM + currK)*LDB + startM*PacketSize]); + zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = + ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]); } - aux_loadB(B_t, LDB, zmm, rem_); + aux_loadB(B_t, LDB, zmm, rem_); } } - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> - aux_loadB( - Scalar *B_t, int64_t LDB, - PacketBlock &zmm, int64_t rem_ = 0) - { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( + Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { EIGEN_UNUSED_VARIABLE(B_t); EIGEN_UNUSED_VARIABLE(LDB); EIGEN_UNUSED_VARIABLE(zmm); @@ -1095,58 +1060,53 @@ public: * for(startN = 0; startN < endN; startN++) * for(startK = 0; startK < endK; startK++) **/ - template - static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> - aux_microKernel( - Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA, - PacketBlock &zmm, int64_t rem_ = 0) { + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel( + Scalar *B_t, Scalar 
+      int64_t rem_ = 0) {
     EIGEN_UNUSED_VARIABLE(rem_);
-    constexpr int64_t counterReverse = endM*endN*endK-counter;
-    constexpr int startK = counterReverse/(endM*endN);
-    constexpr int startN = (counterReverse/(endM))%endN;
-    constexpr int startM = counterReverse%endM;
+    constexpr int64_t counterReverse = endM * endN * endK - counter;
+    constexpr int startK = counterReverse / (endM * endN);
+    constexpr int startN = (counterReverse / (endM)) % endN;
+    constexpr int startM = counterReverse % endM;
     EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
-      gemm:: template
-        startLoadB(B_t, LDB, zmm, rem_);
-      gemm:: template
-        startBCastA(A_t, LDA, zmm);
+      gemm::template startLoadB(B_t, LDB, zmm, rem_);
+      gemm::template startBCastA(A_t, LDA, zmm);
     }
     {  // Interleave FMA and Bcast
       EIGEN_IF_CONSTEXPR(isAdd) {
-        zmm.packet[startN*endM + startM] =
-          pmadd(zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast],
-            zmm.packet[endM*endN+(startM+startK*endM)%numLoad], zmm.packet[startN*endM + startM]);
+        zmm.packet[startN * endM + startM] =
+            pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
+                  zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      } else {
-        zmm.packet[startN*endM + startM] =
-          pnmadd(zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast],
-            zmm.packet[endM*endN+(startM+startK*endM)%numLoad], zmm.packet[startN*endM + startM]);
+        zmm.packet[startN * endM + startM] =
+            pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
+                   zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      }
      // Bcast
-      EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK*endN < endK*endN)) {
-        zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast] =
-          pload1(&A_t[idA((numBCast + startN + startK*endN)%endN,
-            (numBCast + startN + startK*endN)/endN, LDA)]);
+      EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
+        zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1(&A_t[idA(
+            (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
      }
    }
     // We have updated all accumulators, time to load next set of B's
-    EIGEN_IF_CONSTEXPR( (startN == endN - 1) && (startM == endM - 1) ) {
+    EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
       gemm::template loadB(B_t, LDB, zmm, rem_);
    }
-    aux_microKernel(B_t, A_t, LDB, LDA, zmm, rem_);
-
+    aux_microKernel(B_t, A_t, LDB, LDA, zmm, rem_);
  }
-  template
-  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)>
-  aux_microKernel(
-      Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA,
-      PacketBlock &zmm, int64_t rem_ = 0)
-  {
+  template
+  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
+      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm,
+      int64_t rem_ = 0) {
     EIGEN_UNUSED_VARIABLE(B_t);
     EIGEN_UNUSED_VARIABLE(A_t);
     EIGEN_UNUSED_VARIABLE(LDB);
@@ -1159,55 +1119,57 @@ public:
    * Wrappers for aux_XXXX to hide counter parameter
    ********************************************************/
-  template
-  static EIGEN_ALWAYS_INLINE
-  void setzero(PacketBlock &zmm){
-    aux_setzero(zmm);
+  template
+  static EIGEN_ALWAYS_INLINE void setzero(PacketBlock &zmm) {
+    aux_setzero(zmm);
   }
   /**
    * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
   */
-  template
-  static EIGEN_ALWAYS_INLINE
-  void updateC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0){
+  template
+  static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
+                                          PacketBlock &zmm,
+                                          int64_t rem_ = 0) {
     EIGEN_UNUSED_VARIABLE(rem_);
-    aux_updateC(C_arr, LDC, zmm, rem_);
+    aux_updateC(C_arr, LDC, zmm, rem_);
   }
-  template
-  static EIGEN_ALWAYS_INLINE
-  void storeC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0){
+  template
+  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
+                                         PacketBlock &zmm,
+                                         int64_t rem_ = 0) {
     EIGEN_UNUSED_VARIABLE(rem_);
-    aux_storeC(C_arr, LDC, zmm, rem_);
+    aux_storeC(C_arr, LDC, zmm, rem_);
   }
   /**
    * Use numLoad registers for loading B at start of microKernel
-  */
-  template
-  static EIGEN_ALWAYS_INLINE
-  void startLoadB(Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0){
+   */
+  template
+  static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
+                                             PacketBlock &zmm,
+                                             int64_t rem_ = 0) {
     EIGEN_UNUSED_VARIABLE(rem_);
     aux_startLoadB(B_t, LDB, zmm, rem_);
   }
   /**
    * Use numBCast registers for broadcasting A at start of microKernel
-  */
-  template
-  static EIGEN_ALWAYS_INLINE
-  void startBCastA(Scalar *A_t, int64_t LDA, PacketBlock &zmm){
+   */
+  template
+  static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
+                                              PacketBlock &zmm) {
     aux_startBCastA(A_t, LDA, zmm);
   }
   /**
    * Loads next set of B into vector registers between each K unroll.
-  */
-  template
-  static EIGEN_ALWAYS_INLINE
-  void loadB(
-      Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0){
+   */
+  template
+  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
+                                        PacketBlock &zmm,
+                                        int64_t rem_ = 0) {
     EIGEN_UNUSED_VARIABLE(rem_);
     aux_loadB(B_t, LDB, zmm, rem_);
   }
@@ -1235,18 +1197,16 @@ public:
    * From testing, there are no register spills with clang. There are register spills with GNU, which
    * causes a performance hit.
    */
-  template
-  static EIGEN_ALWAYS_INLINE
-  void microKernel(
-      Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA,
-      PacketBlock &zmm, int64_t rem_ = 0){
+  template
+  static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
+                                              PacketBlock &zmm,
+                                              int64_t rem_ = 0) {
     EIGEN_UNUSED_VARIABLE(rem_);
-    aux_microKernel(
-      B_t, A_t, LDB, LDA, zmm, rem_);
+    aux_microKernel(B_t, A_t, LDB, LDA, zmm,
+                    rem_);
   }
-
 };
-} // namespace unrolls
+} // namespace unrolls
-
-#endif //EIGEN_UNROLLS_IMPL_H
+#endif // EIGEN_UNROLLS_IMPL_H
diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h
index 8b9be8b60..707c284e8 100644
--- a/Eigen/src/Core/products/TriangularSolverMatrix.h
+++ b/Eigen/src/Core/products/TriangularSolverMatrix.h
@@ -171,7 +171,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix
 ::value || std::is_same::value)) ) {
@@ -246,7 +246,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix
 ::value || std::is_same::value)) ) {
@@ -318,7 +318,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix
 ::value || std::is_same::value)) ) {
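The aux_* helpers in TrsmUnrolls.inc all follow the same counter-driven unrolling idiom: a pair of function templates selected via std::enable_if_t on `counter`, where `counterReverse` recovers the loop indices and the recursive call peels one iteration per instantiation. The standalone sketch below is illustrative only (it is not Eigen code; `unroll2D` and the printf body are placeholders for the real per-(startM, startN) work such as a pstoreu):

#include <cstdint>
#include <cstdio>
#include <type_traits>

// 2-D compile-time unroll: visits (startM, startN) exactly like the
// "for(startM) for(startN)" loops documented for aux_storeC/aux_updateC.
template <int64_t endM, int64_t endN, int64_t counter>
static std::enable_if_t<(counter > 0)> unroll2D() {
  constexpr int64_t counterReverse = endM * endN - counter;
  constexpr int64_t startM = counterReverse / endN;
  constexpr int64_t startN = counterReverse % endN;
  std::printf("startM=%d startN=%d\n", int(startM), int(startN));  // placeholder body
  unroll2D<endM, endN, counter - 1>();  // peels one iteration per instantiation
}

template <int64_t endM, int64_t endN, int64_t counter>
static std::enable_if_t<(counter <= 0)> unroll2D() {}  // base case ends the recursion

int main() {
  unroll2D<2, 3, 2 * 3>();  // prints (0,0) (0,1) (0,2) (1,0) (1,1) (1,2)
  return 0;
}

Because the real helpers are EIGEN_ALWAYS_INLINE, the compiler flattens this recursion into straight-line code, which is what keeps the accumulators resident in the zmm registers.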
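aux_microKernel recovers three loop indices from a single flattened counter, with startM varying fastest, then startN, then startK. A quick self-contained check of that index mapping, using illustrative values for endM/endN/endK, is:

#include <cassert>
#include <cstdint>

int main() {
  constexpr int64_t endM = 2, endN = 3, endK = 4;
  for (int64_t counterReverse = 0; counterReverse < endM * endN * endK; counterReverse++) {
    const int64_t startK = counterReverse / (endM * endN);
    const int64_t startN = (counterReverse / endM) % endN;
    const int64_t startM = counterReverse % endM;
    // Re-flattening the indices reproduces the counter, so each
    // (startM, startN, startK) triple is visited exactly once across the unroll.
    assert(counterReverse == startK * (endM * endN) + startN * endM + startM);
  }
  return 0;
}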
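The rem/rem_ code paths rely on masked packet loads and stores so that a partial packet of fewer than PacketSize elements can be read or written without touching memory past the end of B or C. A minimal sketch of that idea with raw AVX-512 intrinsics follows; it assumes float data, 16-lane zmm packets and rem <= 16, and it is not Eigen's actual remMask/ploadu/pstoreu implementation:

#include <immintrin.h>
#include <cstdint>

// Store only the first `rem` lanes of src to dst (requires AVX512F).
static void store_partial_ps(float *dst, __m512 src, int64_t rem) {
  const __mmask16 m = static_cast<__mmask16>((1u << rem) - 1u);  // low `rem` bits set
  _mm512_mask_storeu_ps(dst, m, src);
}

// Load the first `rem` lanes from src, zeroing the remaining lanes.
static __m512 load_partial_ps(const float *src, int64_t rem) {
  const __mmask16 m = static_cast<__mmask16>((1u << rem) - 1u);
  return _mm512_maskz_loadu_ps(m, src);
}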