mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-29 15:24:51 +08:00
Add AVX512 optimizations for matrix multiply
This commit is contained in:
parent
00b75375e7
commit
25db0b4a82
@ -191,6 +191,7 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/arch/AVX/MathFunctions.h"
|
||||
#include "src/Core/arch/AVX512/MathFunctions.h"
|
||||
#include "src/Core/arch/AVX512/TrsmKernel.h"
|
||||
#include "src/Core/arch/AVX512/GemmKernel.h"
|
||||
#elif defined EIGEN_VECTORIZE_AVX
|
||||
// Use AVX for floats and doubles, SSE for integers
|
||||
#include "src/Core/arch/SSE/PacketMath.h"
|
||||
|
973
Eigen/src/Core/arch/AVX512/GemmKernel.h
Normal file
973
Eigen/src/Core/arch/AVX512/GemmKernel.h
Normal file
@ -0,0 +1,973 @@
|
||||
#ifndef GEMM_KERNEL_H
|
||||
#define GEMM_KERNEL_H
|
||||
|
||||
#include <x86intrin.h>
|
||||
#include <immintrin.h>
|
||||
#include <type_traits>
|
||||
|
||||
#define SECOND_FETCH (32)
|
||||
|
||||
#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
|
||||
// Use less registers to load A elements to workaround compiler spills. Loose a
|
||||
// bit of performance (less than ~2%).
|
||||
#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
|
||||
#endif
|
||||
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
static inline constexpr int div_up(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
class gemm_class
|
||||
{
|
||||
using vec = typename std::conditional<std::is_same<Scalar, float>::value,
|
||||
Packet16f, Packet8d>::type;
|
||||
using vec_ymm = typename std::conditional<std::is_same<Scalar, float>::value,
|
||||
Packet8f, Packet4d>::type;
|
||||
using vec_xmm = typename std::conditional<std::is_same<Scalar, float>::value,
|
||||
Packet4f, Packet2d>::type;
|
||||
|
||||
static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
|
||||
static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);
|
||||
|
||||
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
|
||||
static constexpr int a_regs[] = {0, 1, 2, 3, 4, 5};
|
||||
#else
|
||||
static constexpr int a_regs[] = {0, 1, 2, 0, 1, 2};
|
||||
#endif
|
||||
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
|
||||
static constexpr int b_regs[] = {6, 7};
|
||||
#else
|
||||
static constexpr int b_regs[] = {6, 6};
|
||||
#endif
|
||||
static constexpr int c_regs[] = {
|
||||
8 , 16, 24,
|
||||
9 , 17, 25,
|
||||
10, 18, 26,
|
||||
11, 19, 27,
|
||||
12, 20, 28,
|
||||
13, 21, 29,
|
||||
14, 22, 30,
|
||||
15, 23, 31,
|
||||
};
|
||||
|
||||
static constexpr int a_shift = 128;
|
||||
static constexpr int b_shift = 128;
|
||||
|
||||
static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
|
||||
static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
|
||||
static constexpr int b_prefetch_size = nelems_in_cache_line * 8;
|
||||
|
||||
vec zmm[32];
|
||||
|
||||
// gemm arguments.
|
||||
int64_t m;
|
||||
const int64_t n, k, ldc;
|
||||
const Scalar *alpha;
|
||||
|
||||
const Scalar *a, *b;
|
||||
Scalar *c;
|
||||
|
||||
const bool is_alpha1;
|
||||
const bool is_beta0;
|
||||
|
||||
const int64_t a_stride, b_stride;
|
||||
const int64_t a_off, b_off;
|
||||
|
||||
public:
|
||||
EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr)
|
||||
{
|
||||
_mm_prefetch((char *) (a_prefetch_size + a_addr - a_shift), _MM_HINT_T0);
|
||||
}
|
||||
|
||||
EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr)
|
||||
{
|
||||
_mm_prefetch((char *) (b_prefetch_size + b_addr - b_shift), _MM_HINT_T0);
|
||||
}
|
||||
|
||||
EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr)
|
||||
{
|
||||
_mm_prefetch((char *) (x_addr - a_shift), _MM_HINT_T2);
|
||||
}
|
||||
|
||||
EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr)
|
||||
{
|
||||
#if defined(__PRFCHW__) && __PRFCHW__ == 1
|
||||
_m_prefetchw((void *) c_addr);
|
||||
#else
|
||||
_mm_prefetch((char *) c_addr, _MM_HINT_T0);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int nelems>
|
||||
EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr)
|
||||
{
|
||||
switch (nelems * sizeof(*a_addr) * 8) {
|
||||
default:
|
||||
case 512 * 3: a_reg = ploadu<vec>(a_addr); break;
|
||||
case 512 * 2: a_reg = ploadu<vec>(a_addr); break;
|
||||
case 512 * 1: a_reg = ploadu<vec>(a_addr); break;
|
||||
case 256 * 1: a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr)))); break;
|
||||
case 128 * 1: a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr)))); break;
|
||||
case 64 * 1: a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr))); break;
|
||||
case 32 * 1: a_reg = pload1<vec>(a_addr); break;
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr)
|
||||
{
|
||||
b_reg = pload1<vec>(b_addr);
|
||||
}
|
||||
|
||||
template <int nelems>
|
||||
EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src)
|
||||
{
|
||||
switch (nelems * sizeof(*mem) * 8) {
|
||||
default:
|
||||
case 512 * 3: pstoreu(mem, src); break;
|
||||
case 512 * 2: pstoreu(mem, src); break;
|
||||
case 512 * 1: pstoreu(mem, src); break;
|
||||
case 256 * 1: pstoreu(mem, preinterpret<vec_ymm>(src)); break;
|
||||
case 128 * 1: pstoreu(mem, preinterpret<vec_xmm>(src)); break;
|
||||
case 64 * 1: pstorel(mem, preinterpret<vec_xmm>(src)); break;
|
||||
case 32 * 1: pstores(mem, preinterpret<vec_xmm>(src)); break;
|
||||
}
|
||||
}
|
||||
|
||||
template <int nelems>
|
||||
EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src)
|
||||
{
|
||||
switch (nelems * sizeof(*mem) * 8) {
|
||||
default:
|
||||
case 512 * 3: dst = padd(src, ploadu<vec>(mem)); break;
|
||||
case 512 * 2: dst = padd(src, ploadu<vec>(mem)); break;
|
||||
case 512 * 1: dst = padd(src, ploadu<vec>(mem)); break;
|
||||
case 256 * 1: dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem))); break;
|
||||
case 128 * 1: dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem))); break;
|
||||
case 64 * 1: dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem))); break;
|
||||
case 32 * 1: dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem))); break;
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
|
||||
dst = pmadd(src1, src2, dst);
|
||||
|
||||
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
|
||||
// Workaround register spills for gcc and clang
|
||||
__asm__ ("#" : [dst] "+v" (dst) : [src1] "%v" (src1), [src2] "v" (src2));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int nelems>
|
||||
EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale)
|
||||
{
|
||||
switch (nelems * sizeof(*mem) * 8) {
|
||||
default:
|
||||
case 512 * 3: dst = pmadd(scale, src, ploadu<vec>(mem)); break;
|
||||
case 512 * 2: dst = pmadd(scale, src, ploadu<vec>(mem)); break;
|
||||
case 512 * 1: dst = pmadd(scale, src, ploadu<vec>(mem)); break;
|
||||
case 256 * 1: dst = preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem))); break;
|
||||
case 128 * 1: dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem))); break;
|
||||
case 64 * 1: dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem))); break;
|
||||
case 32 * 1: dst = preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem))); break;
|
||||
}
|
||||
}
|
||||
|
||||
gemm_class(int64_t m_, int64_t n_, int64_t k_, int64_t ldc_, const Scalar *alpha_,
|
||||
const Scalar *a_, const Scalar *b_, Scalar *c_,
|
||||
bool is_alpha1_, bool is_beta0_,
|
||||
int64_t a_stride_, int64_t b_stride_,
|
||||
int64_t a_off_, int64_t b_off_)
|
||||
: m(m_)
|
||||
, n(n_)
|
||||
, k(k_)
|
||||
, ldc(ldc_)
|
||||
, alpha(alpha_)
|
||||
, a(a_)
|
||||
, b(b_)
|
||||
, c(c_)
|
||||
, is_alpha1(is_alpha1_)
|
||||
, is_beta0(is_beta0_)
|
||||
, a_stride(a_stride_)
|
||||
, b_stride(b_stride_)
|
||||
, a_off(a_off_)
|
||||
, b_off(b_off_)
|
||||
{
|
||||
|
||||
// Zero out all accumulation registers.
|
||||
zmm[8 ] = pzero(zmm[8 ]);
|
||||
zmm[9 ] = pzero(zmm[9 ]);
|
||||
zmm[10] = pzero(zmm[10]);
|
||||
zmm[11] = pzero(zmm[11]);
|
||||
zmm[12] = pzero(zmm[12]);
|
||||
zmm[13] = pzero(zmm[13]);
|
||||
zmm[14] = pzero(zmm[14]);
|
||||
zmm[15] = pzero(zmm[15]);
|
||||
zmm[16] = pzero(zmm[16]);
|
||||
zmm[17] = pzero(zmm[17]);
|
||||
zmm[18] = pzero(zmm[18]);
|
||||
zmm[19] = pzero(zmm[19]);
|
||||
zmm[20] = pzero(zmm[20]);
|
||||
zmm[21] = pzero(zmm[21]);
|
||||
zmm[22] = pzero(zmm[22]);
|
||||
zmm[23] = pzero(zmm[23]);
|
||||
zmm[24] = pzero(zmm[24]);
|
||||
zmm[25] = pzero(zmm[25]);
|
||||
zmm[26] = pzero(zmm[26]);
|
||||
zmm[27] = pzero(zmm[27]);
|
||||
zmm[28] = pzero(zmm[28]);
|
||||
zmm[29] = pzero(zmm[29]);
|
||||
zmm[30] = pzero(zmm[30]);
|
||||
zmm[31] = pzero(zmm[31]);
|
||||
}
|
||||
|
||||
template <int j, int endX, int i, int endY, int nelems>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)>
|
||||
a_loads(const Scalar *ao)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(ao);
|
||||
}
|
||||
|
||||
template <int j, int endX, int i, int endY, int nelems>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)>
|
||||
a_loads(const Scalar *ao)
|
||||
{
|
||||
if (j < endX) {
|
||||
if (i < endY) {
|
||||
auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
|
||||
const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
|
||||
a_load<nelems>(a_reg, a_addr);
|
||||
|
||||
a_loads<j, endX, i + 1, endY, nelems>(ao);
|
||||
} else {
|
||||
a_loads<j + 1, endX, 0, endY, nelems>(ao);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)>
|
||||
prefetch_cs(const Scalar *co1, const Scalar *co2)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(co1);
|
||||
EIGEN_UNUSED_VARIABLE(co2);
|
||||
}
|
||||
|
||||
/* C prefetch loop structure.
|
||||
* for (int un = 0; un < 8; un++) {
|
||||
* if (b_unroll >= un + 1) {
|
||||
* if (un == 4) co2 = co1 + 4 * ldc;
|
||||
*
|
||||
* for (int i = 0; i < um_vecs; i++) {
|
||||
* Scalar *co = (un + 1 <= 4) ? co1 : co2;
|
||||
* auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
|
||||
* prefetch_c(co + co_off);
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
*/
|
||||
|
||||
template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)>
|
||||
prefetch_cs(Scalar *&co1, Scalar *&co2)
|
||||
{
|
||||
if (un < max_b_unroll) {
|
||||
|
||||
if (b_unroll >= un + 1) {
|
||||
if (un == 4 && i == 0) co2 = co1 + 4 * ldc;
|
||||
|
||||
if (i < um_vecs) {
|
||||
Scalar *co = (un + 1 <= 4) ? co1 : co2;
|
||||
auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
|
||||
prefetch_c(co + co_off);
|
||||
|
||||
prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
|
||||
} else {
|
||||
prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
|
||||
}
|
||||
|
||||
} else {
|
||||
prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// load_c
|
||||
template <int i, int um_vecs, int idx, int nelems>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)>
|
||||
scale_load_c(const Scalar *cox, vec &alpha_reg)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(cox);
|
||||
EIGEN_UNUSED_VARIABLE(alpha_reg);
|
||||
}
|
||||
|
||||
template <int i, int um_vecs, int idx, int nelems>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)>
|
||||
scale_load_c(const Scalar *cox, vec &alpha_reg)
|
||||
{
|
||||
|
||||
if (i < um_vecs) {
|
||||
auto &c_reg = zmm[c_regs[i + idx * 3]];
|
||||
auto c_mem = cox + i * nelems_in_cache_line;
|
||||
|
||||
if (!is_beta0 && is_alpha1)
|
||||
vaddm<nelems>(c_reg, c_mem, c_reg);
|
||||
else if (!is_beta0 && !is_alpha1)
|
||||
vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg);
|
||||
else if (is_beta0 && !is_alpha1)
|
||||
c_reg = pmul(alpha_reg, c_reg);
|
||||
|
||||
scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
|
||||
}
|
||||
}
|
||||
|
||||
// store_c
|
||||
template <int i, int um_vecs, int idx, int nelems>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)>
|
||||
write_c(Scalar *cox)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(cox);
|
||||
}
|
||||
|
||||
template <int i, int um_vecs, int idx, int nelems>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)>
|
||||
write_c(Scalar *cox)
|
||||
{
|
||||
if (i < um_vecs) {
|
||||
auto &c_reg = zmm[c_regs[i + idx * 3]];
|
||||
auto c_mem = cox + i * nelems_in_cache_line;
|
||||
|
||||
c_store<nelems>(c_mem, c_reg);
|
||||
c_reg = pzero(c_reg);
|
||||
|
||||
write_c<i + 1, um_vecs, idx, nelems>(cox);
|
||||
}
|
||||
}
|
||||
|
||||
// update c matrix
|
||||
template <int pow, int max_b_unroll, int count, int a_unroll, int b_unroll, int idx>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(pow > (max_b_unroll << 1)) || (count > (pow + 1) / 2 + 1)>
|
||||
c_update(Scalar *&co1, Scalar *&co2)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(co1);
|
||||
EIGEN_UNUSED_VARIABLE(co2);
|
||||
}
|
||||
|
||||
/* C update loop structure.
|
||||
* co2 = co1 + ldc;
|
||||
*
|
||||
* auto &alpha_reg = zmm[0];
|
||||
* if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
|
||||
*
|
||||
* int idx = 0;
|
||||
* for (pow = 1; pow <= 8; pow <<= 1) {
|
||||
*
|
||||
* if (b_unroll >= pow) {
|
||||
* for (count = 1; count < (pow + 1) / 2 + 1; count++) {
|
||||
* if (pow >= 4) co2 += ldc;
|
||||
*
|
||||
* const Scalar *cox = (idx == 0) ? co1 : co2;
|
||||
*
|
||||
* const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
|
||||
* scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
|
||||
* write_c<0, um_vecs, idx, a_unroll>(cox);
|
||||
*
|
||||
* idx++;
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* if (b_unroll == 1)
|
||||
* co1 += ldc;
|
||||
* else
|
||||
* co1 = co2 + ldc;
|
||||
*/
|
||||
|
||||
template <int pow, int max_b_unroll, int count, int a_unroll, int b_unroll, int idx>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(pow <= (max_b_unroll << 1)) && (count <= (pow + 1) / 2 + 1)>
|
||||
c_update(Scalar *&co1, Scalar *&co2)
|
||||
{
|
||||
const bool first_call = idx == 0;
|
||||
auto &alpha_reg = zmm[0];
|
||||
|
||||
if (first_call) {
|
||||
co2 = co1 + ldc;
|
||||
if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
|
||||
}
|
||||
|
||||
if (pow < (max_b_unroll << 1) && pow <= b_unroll) {
|
||||
if (count < (pow + 1) / 2 + 1) {
|
||||
if (pow >= 4) co2 += ldc;
|
||||
|
||||
Scalar *cox = idx == 0 ? co1 : co2;
|
||||
|
||||
const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
|
||||
scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
|
||||
write_c<0, um_vecs, idx, a_unroll>(cox);
|
||||
|
||||
// Go to the next count and next idx.
|
||||
c_update<pow, max_b_unroll, count + 1, a_unroll, b_unroll, idx + 1>(co1, co2);
|
||||
} else {
|
||||
// Go to the next pow and reset count.
|
||||
c_update<pow << 1, max_b_unroll, 1, a_unroll, b_unroll, idx>(co1, co2);
|
||||
}
|
||||
} else {
|
||||
if (b_unroll == 1)
|
||||
co1 += ldc;
|
||||
else
|
||||
co1 = co2 + ldc;
|
||||
}
|
||||
}
|
||||
|
||||
// compute
|
||||
template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)>
|
||||
compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(ao);
|
||||
EIGEN_UNUSED_VARIABLE(bo);
|
||||
EIGEN_UNUSED_VARIABLE(fetchA_idx);
|
||||
EIGEN_UNUSED_VARIABLE(fetchB_idx);
|
||||
EIGEN_UNUSED_VARIABLE(b_reg);
|
||||
}
|
||||
|
||||
template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)>
|
||||
compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg)
|
||||
{
|
||||
if (um < um_vecs) {
|
||||
auto &c_reg = zmm[c_regs[um + idx * 3]];
|
||||
auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
|
||||
|
||||
vfmadd(c_reg, a_reg, b_reg);
|
||||
|
||||
if (!fetch_x && um == 0 && (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) || (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
|
||||
prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
|
||||
fetchA_idx++;
|
||||
}
|
||||
|
||||
if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
|
||||
prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
|
||||
fetchB_idx++;
|
||||
}
|
||||
|
||||
compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
|
||||
}
|
||||
}
|
||||
|
||||
// load_a
|
||||
template <int um, int um_vecs, int uk, int nelems, bool ktail>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)>
|
||||
load_a(const Scalar *ao)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(ao);
|
||||
}
|
||||
|
||||
template <int um, int um_vecs, int uk, int nelems, bool ktail>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)>
|
||||
load_a(const Scalar *ao)
|
||||
{
|
||||
if (um < um_vecs) {
|
||||
auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
|
||||
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
|
||||
const Scalar *a_addr = ao + nelems * (1 + !ktail + uk) + nelems_in_cache_line * um - a_shift;
|
||||
#else
|
||||
const Scalar *a_addr = ao + nelems * (1 + uk) + nelems_in_cache_line * um - a_shift;
|
||||
#endif
|
||||
a_load<nelems>(a_reg, a_addr);
|
||||
|
||||
load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
|
||||
}
|
||||
}
|
||||
template<int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)>
|
||||
innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(aa);
|
||||
EIGEN_UNUSED_VARIABLE(ao);
|
||||
EIGEN_UNUSED_VARIABLE(bo);
|
||||
EIGEN_UNUSED_VARIABLE(co2);
|
||||
EIGEN_UNUSED_VARIABLE(fetchA_idx);
|
||||
EIGEN_UNUSED_VARIABLE(fetchB_idx);
|
||||
}
|
||||
|
||||
template<int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
|
||||
EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)>
|
||||
innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx)
|
||||
{
|
||||
const int idx = (pow / 2) + count;
|
||||
|
||||
if (count < (pow + 1) / 2) {
|
||||
auto &b_reg = zmm[b_regs[idx % 2]];
|
||||
|
||||
if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
|
||||
if (fetch_x && uk == 3 && idx == 4) aa += 8;
|
||||
|
||||
if (b_unroll >= pow) {
|
||||
|
||||
compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
|
||||
|
||||
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
|
||||
const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift;
|
||||
#else
|
||||
const Scalar *b_addr = bo + b_unroll * uk + idx + 1 - b_shift;
|
||||
#endif
|
||||
b_load(b_reg, b_addr);
|
||||
}
|
||||
|
||||
// Go to the next count.
|
||||
innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
|
||||
} else {
|
||||
// Maybe prefetch C data after count-loop.
|
||||
if (pow == 2 && c_fetch) {
|
||||
if (uk % 3 == 0 && uk > 0) {
|
||||
co2 += ldc;
|
||||
} else {
|
||||
prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
|
||||
EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx)
|
||||
{
|
||||
const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
|
||||
|
||||
if (max_b_unroll >= 1) innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
if (max_b_unroll >= 2) innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
if (max_b_unroll >= 4) innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
if (max_b_unroll >= 8) innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
|
||||
// Load A after pow-loop.
|
||||
load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
|
||||
}
|
||||
|
||||
/* Inner kernel loop structure.
|
||||
* for (int uk = 0; uk < kfactor; uk++) {
|
||||
* int idx = 0;
|
||||
*
|
||||
* for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) {
|
||||
* for (int count = 0; count < (pow + 1) / 2; count++) {
|
||||
* auto &b_reg = zmm[b_regs[idx % 2]];
|
||||
*
|
||||
* if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
|
||||
* if (fetch_x && uk == 3 && idx == 4) aa += 8;
|
||||
*
|
||||
* if (b_unroll >= pow) {
|
||||
* compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
|
||||
*
|
||||
* const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ;
|
||||
* b_load(b_reg, b_addr);
|
||||
* }
|
||||
* idx++;
|
||||
* }
|
||||
*
|
||||
* Maybe prefetch C data.
|
||||
* if (pow == 2 && c_fetch) {
|
||||
* if (uk % 3 == 0 && uk > 0) {
|
||||
* co2 += ldc;
|
||||
* } else {
|
||||
* prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* Load A.
|
||||
* load_a<0, um_vecs, uk, ktail, a_unroll>(ao);
|
||||
* }
|
||||
*
|
||||
* Advance A/B pointers after uk-loop.
|
||||
* ao += a_unroll * kfactor;
|
||||
* bo += b_unroll * kfactor;
|
||||
*/
|
||||
|
||||
template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch>
|
||||
EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2)
|
||||
{
|
||||
int fetchA_idx = 0;
|
||||
int fetchB_idx = 0;
|
||||
|
||||
const bool fetch_x = k_factor == max_k_factor;
|
||||
const bool ktail = k_factor == 1;
|
||||
|
||||
static_assert(k_factor <= 4 && k_factor > 0,
|
||||
"innerkernel maximum k_factor supported is 4");
|
||||
|
||||
if (k_factor > 0) innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
if (k_factor > 1) innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
if (k_factor > 2) innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
if (k_factor > 3) innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
|
||||
|
||||
// Advance A/B pointers after uk-loop.
|
||||
ao += a_unroll * k_factor;
|
||||
bo += b_unroll * k_factor;
|
||||
}
|
||||
|
||||
|
||||
template <int a_unroll, int b_unroll, int max_b_unroll>
|
||||
EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2)
|
||||
{
|
||||
const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
|
||||
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
|
||||
a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
|
||||
#else
|
||||
a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
|
||||
#endif
|
||||
|
||||
b_load(zmm[b_regs[0]], bo - b_shift + 0);
|
||||
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
|
||||
b_load(zmm[b_regs[1]], bo - b_shift + 1);
|
||||
#endif
|
||||
|
||||
#ifndef SECOND_FETCH
|
||||
prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
|
||||
#endif // SECOND_FETCH
|
||||
|
||||
// Unrolling k-loop by a factor of 4.
|
||||
const int max_k_factor = 4;
|
||||
int64_t loop_count = k / max_k_factor;
|
||||
|
||||
if (loop_count > 0) {
|
||||
#ifdef SECOND_FETCH
|
||||
loop_count -= SECOND_FETCH;
|
||||
#endif
|
||||
while (loop_count > 0) {
|
||||
innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
|
||||
loop_count--;
|
||||
}
|
||||
#ifdef SECOND_FETCH
|
||||
co2 = co1 + nelems_in_cache_line - 1;
|
||||
|
||||
loop_count += b_unroll;
|
||||
while (loop_count > 0) {
|
||||
innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
|
||||
loop_count--;
|
||||
}
|
||||
|
||||
loop_count += SECOND_FETCH - b_unroll;
|
||||
while (loop_count > 0) {
|
||||
innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
|
||||
loop_count--;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// k-loop remainder handling.
|
||||
loop_count = k % max_k_factor;
|
||||
while (loop_count > 0) {
|
||||
innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
|
||||
loop_count--;
|
||||
}
|
||||
|
||||
// Update C matrix.
|
||||
c_update<1, max_b_unroll, 1, a_unroll, b_unroll, 0>(co1, co2);
|
||||
}
|
||||
|
||||
template <int a_unroll, int b_unroll, int max_b_unroll>
|
||||
EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2)
|
||||
{
|
||||
// Set A matrix pointer.
|
||||
ao = a + a_off * a_unroll;
|
||||
|
||||
// Set B matrix pointer if needed.
|
||||
bo += b_unroll * b_off;
|
||||
|
||||
kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
|
||||
|
||||
// Advance B matrix pointer if needed.
|
||||
bo += b_unroll * (b_stride - k - b_off);
|
||||
|
||||
// Advance prefetch A pointer.
|
||||
aa += 16;
|
||||
}
|
||||
|
||||
template <int a_unroll, int max_a_unroll, int max_b_unroll>
|
||||
EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2)
|
||||
{
|
||||
// Set prefetch A pointers.
|
||||
const Scalar *aa = a + a_unroll * a_stride;
|
||||
|
||||
// Set C matrix pointers.
|
||||
co1 = c;
|
||||
if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
|
||||
c += a_unroll;
|
||||
|
||||
// Set B matrix pointer.
|
||||
bo = b;
|
||||
|
||||
// Main n-loop.
|
||||
for (int64_t i = n / max_b_unroll; i > 0; i--)
|
||||
nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
|
||||
|
||||
// n-remainders.
|
||||
if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
|
||||
#if 0
|
||||
if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
|
||||
if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
|
||||
#else
|
||||
// Copy kernels don't support tails of n = 2 for single/double precision.
|
||||
// Loop over ones.
|
||||
int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
|
||||
while (n_rem > 0) {nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2); n_rem--;}
|
||||
#endif
|
||||
|
||||
// Advance A matrix pointer.
|
||||
a = ao + a_unroll * (a_stride - k - a_off);
|
||||
}
|
||||
|
||||
// Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll.
|
||||
template <int max_a_unroll, int max_b_unroll>
|
||||
EIGEN_ALWAYS_INLINE void compute_kern()
|
||||
{
|
||||
a -= -a_shift;
|
||||
b -= -b_shift;
|
||||
|
||||
const Scalar *ao = nullptr;
|
||||
const Scalar *bo = nullptr;
|
||||
Scalar *co1 = nullptr;
|
||||
Scalar *co2 = nullptr;
|
||||
|
||||
// Main m-loop.
|
||||
for (; m >= max_a_unroll; m -= max_a_unroll)
|
||||
mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
|
||||
|
||||
// m-remainders.
|
||||
if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
|
||||
if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
|
||||
if (m & 8 && max_a_unroll > 8) mloop< 8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
|
||||
if (m & 4 && max_a_unroll > 4) mloop< 4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
|
||||
if (m & 2 && max_a_unroll > 2 && is_f64) mloop< 2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
|
||||
if (m & 1 && max_a_unroll > 1 && is_f64) mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
|
||||
|
||||
// Copy kernels don't support tails of m = 2 for single precision.
|
||||
// Loop over ones.
|
||||
if (is_f32) {
|
||||
int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
|
||||
while (m_rem > 0) {mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); m_rem--;}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Compute kernel with max unroll support of:
|
||||
// Single precision:
|
||||
// max_a_unroll: 48, 32, 16, 8, 4, 2, 1
|
||||
// max_b_unroll: 8, 4, 2, 1
|
||||
// Double precision:
|
||||
// max_a_unroll: 24, 16, 8, 4, 2, 1
|
||||
// max_b_unroll: 8, 4, 2, 1
|
||||
template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0>
|
||||
EIGEN_DONT_INLINE void gemm_kern_avx512(int64_t *p_m, int64_t *p_n, int64_t *p_k,
|
||||
Scalar *alpha, const Scalar *a, const Scalar *b, Scalar *c,
|
||||
int64_t ldc, int64_t a_stride = -1, int64_t b_stride = -1,
|
||||
int64_t a_off = 0, int64_t b_off = 0)
|
||||
{
|
||||
if (a_stride == -1) a_stride = *p_k;
|
||||
if (b_stride == -1) b_stride = *p_k;
|
||||
|
||||
gemm_class<Scalar> g(*p_m, *p_n, *p_k, ldc, alpha, a, b, c,
|
||||
is_alpha1, is_beta0, a_stride, b_stride, a_off, b_off);
|
||||
g.template compute_kern<max_a_unroll, max_b_unroll>();
|
||||
}
|
||||
|
||||
template <typename a_t, typename b_t, typename c_t>
|
||||
bool gemm_kernel(int64_t m, int64_t n, int64_t k, c_t alpha,
|
||||
const a_t *a, const b_t *b, c_t *c, int64_t ldc,
|
||||
int64_t a_stride = -1, int64_t b_stride = -1,
|
||||
int64_t a_off = 0, int64_t b_off = 0)
|
||||
{
|
||||
EIGEN_UNUSED_VARIABLE(m);
|
||||
EIGEN_UNUSED_VARIABLE(n);
|
||||
EIGEN_UNUSED_VARIABLE(k);
|
||||
EIGEN_UNUSED_VARIABLE(alpha);
|
||||
EIGEN_UNUSED_VARIABLE(a);
|
||||
EIGEN_UNUSED_VARIABLE(b);
|
||||
EIGEN_UNUSED_VARIABLE(c);
|
||||
EIGEN_UNUSED_VARIABLE(ldc);
|
||||
EIGEN_UNUSED_VARIABLE(a_stride);
|
||||
EIGEN_UNUSED_VARIABLE(b_stride);
|
||||
EIGEN_UNUSED_VARIABLE(a_off);
|
||||
EIGEN_UNUSED_VARIABLE(b_off);
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool gemm_kernel(int64_t m, int64_t n, int64_t k, float alpha,
|
||||
const float *a, const float *b, float *c, int64_t ldc,
|
||||
int64_t a_stride, int64_t b_stride,
|
||||
int64_t a_off, int64_t b_off)
|
||||
{
|
||||
if (alpha == 1.f)
|
||||
gemm_kern_avx512<float, 48, 8, true, false>(&m, &n, &k, &alpha, a, b, c,
|
||||
ldc, a_stride, b_stride, a_off, b_off);
|
||||
else
|
||||
gemm_kern_avx512<float, 48, 8, false, false>(&m, &n, &k, &alpha, a, b, c,
|
||||
ldc, a_stride, b_stride, a_off, b_off);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool gemm_kernel(int64_t m, int64_t n, int64_t k, double alpha,
|
||||
const double *a, const double *b, double *c, int64_t ldc,
|
||||
int64_t a_stride, int64_t b_stride,
|
||||
int64_t a_off, int64_t b_off)
|
||||
{
|
||||
if (alpha == 1.)
|
||||
gemm_kern_avx512<double, 24, 8, true, false>(&m, &n, &k, &alpha, a, b, c,
|
||||
ldc, a_stride, b_stride, a_off, b_off);
|
||||
else
|
||||
gemm_kern_avx512<double, 24, 8, false, false>(&m, &n, &k, &alpha, a, b, c,
|
||||
ldc, a_stride, b_stride, a_off, b_off);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate, bool PanelMode>
|
||||
struct gemm_pack_rhs;
|
||||
|
||||
template<typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
|
||||
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>
|
||||
{
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
typedef typename DataMapper::LinearMapper LinearMapper;
|
||||
enum { PacketSize = packet_traits<Scalar>::size };
|
||||
EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
|
||||
};
|
||||
|
||||
template<typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
|
||||
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>
|
||||
::operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
|
||||
{
|
||||
constexpr int nr = 8;
|
||||
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
|
||||
EIGEN_UNUSED_VARIABLE(stride);
|
||||
EIGEN_UNUSED_VARIABLE(offset);
|
||||
eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
|
||||
conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
|
||||
Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
|
||||
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
|
||||
Index count = 0;
|
||||
const Index peeled_k = (depth/PacketSize)*PacketSize;
|
||||
if(nr>=8)
|
||||
{
|
||||
for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
{
|
||||
// skip what we have before
|
||||
if(PanelMode) count += 8 * offset;
|
||||
const LinearMapper dm0 = rhs.getLinearMapper(0, j2+0);
|
||||
const LinearMapper dm1 = rhs.getLinearMapper(0, j2+1);
|
||||
const LinearMapper dm2 = rhs.getLinearMapper(0, j2+2);
|
||||
const LinearMapper dm3 = rhs.getLinearMapper(0, j2+3);
|
||||
const LinearMapper dm4 = rhs.getLinearMapper(0, j2+4);
|
||||
const LinearMapper dm5 = rhs.getLinearMapper(0, j2+5);
|
||||
const LinearMapper dm6 = rhs.getLinearMapper(0, j2+6);
|
||||
const LinearMapper dm7 = rhs.getLinearMapper(0, j2+7);
|
||||
Index k=0;
|
||||
if((PacketSize%8)==0) // TODO enable vectorized transposition for PacketSize==4
|
||||
{
|
||||
for(; k<peeled_k; k+=PacketSize) {
|
||||
PacketBlock<Packet,(PacketSize%8)==0?8:PacketSize> kernel;
|
||||
|
||||
kernel.packet[0] = dm0.template loadPacket<Packet>(k);
|
||||
kernel.packet[1] = dm1.template loadPacket<Packet>(k);
|
||||
kernel.packet[2] = dm2.template loadPacket<Packet>(k);
|
||||
kernel.packet[3] = dm3.template loadPacket<Packet>(k);
|
||||
kernel.packet[4] = dm4.template loadPacket<Packet>(k);
|
||||
kernel.packet[5] = dm5.template loadPacket<Packet>(k);
|
||||
kernel.packet[6] = dm6.template loadPacket<Packet>(k);
|
||||
kernel.packet[7] = dm7.template loadPacket<Packet>(k);
|
||||
|
||||
ptranspose(kernel);
|
||||
|
||||
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0]));
|
||||
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize]));
|
||||
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize]));
|
||||
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize]));
|
||||
pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel.packet[4%PacketSize]));
|
||||
pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel.packet[5%PacketSize]));
|
||||
pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel.packet[6%PacketSize]));
|
||||
pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel.packet[7%PacketSize]));
|
||||
count+=8*PacketSize;
|
||||
}
|
||||
}
|
||||
for(; k<depth; k++)
|
||||
{
|
||||
blockB[count+0] = cj(dm0(k));
|
||||
blockB[count+1] = cj(dm1(k));
|
||||
blockB[count+2] = cj(dm2(k));
|
||||
blockB[count+3] = cj(dm3(k));
|
||||
blockB[count+4] = cj(dm4(k));
|
||||
blockB[count+5] = cj(dm5(k));
|
||||
blockB[count+6] = cj(dm6(k));
|
||||
blockB[count+7] = cj(dm7(k));
|
||||
count += 8;
|
||||
}
|
||||
// skip what we have after
|
||||
if(PanelMode) count += 8 * (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
|
||||
if(nr>=4)
|
||||
{
|
||||
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
|
||||
{
|
||||
// skip what we have before
|
||||
if(PanelMode) count += 4 * offset;
|
||||
const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
|
||||
const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
|
||||
const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
|
||||
const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
|
||||
|
||||
Index k=0;
|
||||
if((PacketSize%4)==0) // TODO enable vectorized transposition for PacketSize==2 ??
|
||||
{
|
||||
for(; k<peeled_k; k+=PacketSize) {
|
||||
PacketBlock<Packet,(PacketSize%4)==0?4:PacketSize> kernel;
|
||||
kernel.packet[0 ] = dm0.template loadPacket<Packet>(k);
|
||||
kernel.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
|
||||
kernel.packet[2%PacketSize] = dm2.template loadPacket<Packet>(k);
|
||||
kernel.packet[3%PacketSize] = dm3.template loadPacket<Packet>(k);
|
||||
ptranspose(kernel);
|
||||
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0]));
|
||||
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize]));
|
||||
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize]));
|
||||
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize]));
|
||||
count+=4*PacketSize;
|
||||
}
|
||||
}
|
||||
for(; k<depth; k++)
|
||||
{
|
||||
blockB[count+0] = cj(dm0(k));
|
||||
blockB[count+1] = cj(dm1(k));
|
||||
blockB[count+2] = cj(dm2(k));
|
||||
blockB[count+3] = cj(dm3(k));
|
||||
count += 4;
|
||||
}
|
||||
// skip what we have after
|
||||
if(PanelMode) count += 4 * (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
|
||||
// copy the remaining columns one at a time (nr==1)
|
||||
for(Index j2=packet_cols4; j2<cols; ++j2)
|
||||
{
|
||||
if(PanelMode) count += offset;
|
||||
const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
blockB[count] = cj(dm0(k));
|
||||
count += 1;
|
||||
}
|
||||
if(PanelMode) count += (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Eigen
|
||||
} // namespace internal
|
||||
|
||||
#endif // GEMM_KERNEL_H
|
@ -1432,6 +1432,7 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 16>& kernel) {
|
||||
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \
|
||||
INPUT[2 * INDEX + STRIDE]);
|
||||
|
||||
template<bool for_trsm = false>
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
|
||||
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]);
|
||||
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]);
|
||||
@ -1450,28 +1451,49 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
|
||||
kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4),_mm512_castps_pd(T6)));
|
||||
kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
|
||||
kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
|
||||
|
||||
T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
|
||||
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
|
||||
T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
|
||||
T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
|
||||
T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
|
||||
T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
|
||||
T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
|
||||
T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
|
||||
T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
|
||||
T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
|
||||
T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
|
||||
T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
|
||||
T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
|
||||
T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
|
||||
T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
|
||||
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
|
||||
|
||||
kernel.packet[0] = T0; kernel.packet[1] = T1;
|
||||
kernel.packet[2] = T2; kernel.packet[3] = T3;
|
||||
kernel.packet[4] = T4; kernel.packet[5] = T5;
|
||||
kernel.packet[6] = T6; kernel.packet[7] = T7;
|
||||
// Transpose for gemm is slightly different than trsm.
|
||||
if (!for_trsm) {
|
||||
T0 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0x44);
|
||||
T1 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0xee);
|
||||
T2 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0x44);
|
||||
T3 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0xee);
|
||||
T4 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0x44);
|
||||
T5 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0xee);
|
||||
T6 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0x44);
|
||||
T7 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0xee);
|
||||
|
||||
kernel.packet[0] = _mm512_shuffle_f32x4(T0, T2, 0x88);
|
||||
kernel.packet[2] = _mm512_shuffle_f32x4(T0, T2, 0xdd);
|
||||
kernel.packet[1] = _mm512_shuffle_f32x4(T4, T6, 0x88);
|
||||
kernel.packet[3] = _mm512_shuffle_f32x4(T4, T6, 0xdd);
|
||||
kernel.packet[4] = _mm512_shuffle_f32x4(T1, T3, 0x88);
|
||||
kernel.packet[6] = _mm512_shuffle_f32x4(T1, T3, 0xdd);
|
||||
kernel.packet[5] = _mm512_shuffle_f32x4(T5, T7, 0x88);
|
||||
kernel.packet[7] = _mm512_shuffle_f32x4(T5, T7, 0xdd);
|
||||
} else {
|
||||
T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
|
||||
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
|
||||
T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
|
||||
T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
|
||||
T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
|
||||
T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
|
||||
T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
|
||||
T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
|
||||
T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
|
||||
T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
|
||||
T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
|
||||
T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
|
||||
T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
|
||||
T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
|
||||
T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
|
||||
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
|
||||
|
||||
kernel.packet[0] = T0; kernel.packet[1] = T1;
|
||||
kernel.packet[2] = T2; kernel.packet[3] = T3;
|
||||
kernel.packet[4] = T4; kernel.packet[5] = T5;
|
||||
kernel.packet[6] = T6; kernel.packet[7] = T7;
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
|
||||
@ -1549,7 +1571,9 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
|
||||
PACK_OUTPUT_D(kernel.packet, tmp.packet, 3, 1);
|
||||
}
|
||||
|
||||
template<bool for_trsm = false>
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
|
||||
// Transpose for trsm is the same as for gemm.
|
||||
__m512d T0 = _mm512_unpacklo_pd(kernel.packet[0],kernel.packet[1]);
|
||||
__m512d T1 = _mm512_unpackhi_pd(kernel.packet[0],kernel.packet[1]);
|
||||
__m512d T2 = _mm512_unpacklo_pd(kernel.packet[2],kernel.packet[3]);
|
||||
|
@ -198,7 +198,7 @@ public:
|
||||
r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5];
|
||||
r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6];
|
||||
r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7];
|
||||
ptranspose(r);
|
||||
ptranspose<true>(r);
|
||||
zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0];
|
||||
zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1];
|
||||
zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2];
|
||||
|
@ -44,6 +44,34 @@ template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const
|
||||
return _mm512_castps512_ps256(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet16f>(const Packet16f& a) {
|
||||
return _mm512_castps512_ps128(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet8d>(const Packet8d& a) {
|
||||
return _mm512_castpd512_pd256(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet8d>(const Packet8d& a) {
|
||||
return _mm512_castpd512_pd128(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8f>(const Packet8f& a) {
|
||||
return _mm512_castps256_ps512(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet4f>(const Packet4f& a) {
|
||||
return _mm512_castps128_ps512(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet4d>(const Packet4d& a) {
|
||||
return _mm512_castpd256_pd512(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const Packet2d& a) {
|
||||
return _mm512_castpd128_pd512(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
|
||||
return a;
|
||||
}
|
||||
|
@ -8,9 +8,9 @@ namespace internal {
|
||||
// Clang seems to excessively spill registers in the GEBP kernel on 32-bit arm.
|
||||
// Here we specialize gebp_traits to eliminate these register spills.
|
||||
// See #2138.
|
||||
template<>
|
||||
struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
|
||||
: gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
|
||||
template<bool UnitResIncr>
|
||||
struct gebp_traits <float,float,UnitResIncr,false,false,Architecture::NEON,GEBPPacketFull>
|
||||
: gebp_traits<float,float,UnitResIncr,false,false,Architecture::Generic,GEBPPacketFull>
|
||||
{
|
||||
EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
|
||||
{
|
||||
@ -43,9 +43,9 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
|
||||
|
||||
#if EIGEN_ARCH_ARM64
|
||||
|
||||
template<>
|
||||
struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
|
||||
: gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
|
||||
template<bool UnitResIncr>
|
||||
struct gebp_traits <float,float,UnitResIncr,false,false,Architecture::NEON,GEBPPacketFull>
|
||||
: gebp_traits<float,float,UnitResIncr,false,false,Architecture::Generic,GEBPPacketFull>
|
||||
{
|
||||
typedef float RhsPacket;
|
||||
typedef float32x4_t RhsPacketx4;
|
||||
@ -108,9 +108,9 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
|
||||
};
|
||||
|
||||
|
||||
template<>
|
||||
struct gebp_traits <double,double,false,false,Architecture::NEON>
|
||||
: gebp_traits<double,double,false,false,Architecture::Generic>
|
||||
template<bool UnitResIncr>
|
||||
struct gebp_traits <double,double,UnitResIncr,false,false,Architecture::NEON>
|
||||
: gebp_traits<double,double,UnitResIncr,false,false,Architecture::Generic>
|
||||
{
|
||||
typedef double RhsPacket;
|
||||
|
||||
|
@ -285,6 +285,10 @@ template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16b padd<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); }
|
||||
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet padds(const Packet& a, const Packet& b);
|
||||
template<> EIGEN_STRONG_INLINE Packet4f padds<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_add_ss(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d padds<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_add_sd(a,b); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_sub_ps(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_sub_pd(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_sub_epi32(a,b); }
|
||||
@ -370,6 +374,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmadd_pd(a,b,c); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmsub_ps(a,b,c); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmsub_pd(a,b,c); }
|
||||
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet pmadds(const Packet& a, const Packet& b, const Packet& c);
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pmadds<Packet4f>(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ss(a,b,c); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pmadds<Packet2d>(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_sd(a,b,c); }
|
||||
#endif
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_SSE4_1
|
||||
@ -746,6 +754,15 @@ template<> EIGEN_STRONG_INLINE Packet16b ploadu<Packet16b>(const bool* from)
|
||||
return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
|
||||
}
|
||||
|
||||
// Load lower part of packet zero extending.
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet ploadl(const typename unpacket_traits<Packet>::type* from);
|
||||
template<> EIGEN_STRONG_INLINE Packet4f ploadl<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(from))); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d ploadl<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); }
|
||||
|
||||
// Load scalar
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet ploads(const typename unpacket_traits<Packet>::type* from);
|
||||
template<> EIGEN_STRONG_INLINE Packet4f ploads<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_ss(from); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d ploads<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
|
||||
{
|
||||
@ -787,6 +804,14 @@ template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f&
|
||||
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
|
||||
template<> EIGEN_STRONG_INLINE void pstoreu<bool>(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
|
||||
|
||||
template<typename Scalar, typename Packet> EIGEN_STRONG_INLINE void pstorel(Scalar* to, const Packet& from);
|
||||
template<> EIGEN_STRONG_INLINE void pstorel(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pi(reinterpret_cast<__m64*>(to), from); }
|
||||
template<> EIGEN_STRONG_INLINE void pstorel(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pd(to, from); }
|
||||
|
||||
template<typename Scalar, typename Packet> EIGEN_STRONG_INLINE void pstores(Scalar* to, const Packet& from);
|
||||
template<> EIGEN_STRONG_INLINE void pstores(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_ss(to, from); }
|
||||
template<> EIGEN_STRONG_INLINE void pstores(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_sd(to, from); }
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
|
||||
{
|
||||
return _mm_set_ps(from[3*stride], from[2*stride], from[1*stride], from[0*stride]);
|
||||
|
@ -71,6 +71,14 @@ template<> EIGEN_STRONG_INLINE Packet2d pcast<Packet4f, Packet2d>(const Packet4f
|
||||
return _mm_cvtps_pd(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet4f>(const Packet4f& a) {
|
||||
return _mm_castps_pd(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet2d>(const Packet2d& a) {
|
||||
return _mm_castpd_ps(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
|
||||
return _mm_castps_si128(a);
|
||||
}
|
||||
|
@ -2,6 +2,7 @@
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||
// Modifications Copyright (C) 2022 Intel Corporation
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
@ -23,7 +24,7 @@ enum GEBPPacketSizeType {
|
||||
GEBPPacketQuarter
|
||||
};
|
||||
|
||||
template<typename LhsScalar_, typename RhsScalar_, bool ConjLhs_=false, bool ConjRhs_=false, int Arch=Architecture::Target, int PacketSize_=GEBPPacketFull>
|
||||
template<typename LhsScalar_, typename RhsScalar_, bool UnitResIncr=false, bool ConjLhs_=false, bool ConjRhs_=false, int Arch=Architecture::Target, int PacketSize_=GEBPPacketFull>
|
||||
class gebp_traits;
|
||||
|
||||
|
||||
@ -125,7 +126,7 @@ inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1, std::ptrdiff
|
||||
template<typename LhsScalar, typename RhsScalar, int KcFactor, typename Index>
|
||||
void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index num_threads = 1)
|
||||
{
|
||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar, true> Traits;
|
||||
|
||||
// Explanations:
|
||||
// Let's recall that the product algorithms form mc x kc vertical panels A' on the lhs and
|
||||
@ -416,7 +417,7 @@ struct packet_conditional<GEBPPacketHalf, T1, T2, T3> { typedef T2 type; };
|
||||
* cplx*real : unpack rhs to constant packets, ...
|
||||
* real*cplx : load lhs as (a0,a0,a1,a1), and mul as usual
|
||||
*/
|
||||
template<typename LhsScalar_, typename RhsScalar_, bool ConjLhs_, bool ConjRhs_, int Arch, int PacketSize_>
|
||||
template<typename LhsScalar_, typename RhsScalar_, bool UnitResIncr_, bool ConjLhs_, bool ConjRhs_, int Arch, int PacketSize_>
|
||||
class gebp_traits
|
||||
{
|
||||
public:
|
||||
@ -429,6 +430,7 @@ public:
|
||||
PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
|
||||
|
||||
enum {
|
||||
UnitResIncr = UnitResIncr_,
|
||||
ConjLhs = ConjLhs_,
|
||||
ConjRhs = ConjRhs_,
|
||||
Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable,
|
||||
@ -437,9 +439,17 @@ public:
|
||||
ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
|
||||
|
||||
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
||||
IsReal = std::is_same<LhsScalar, RhsScalar>::value
|
||||
&& (std::is_same<LhsScalar, float>::value
|
||||
|| std::is_same<LhsScalar, double>::value),
|
||||
|
||||
// register block size along the N direction must be 1 or 4
|
||||
#if defined(EIGEN_VECTORIZE_AVX512)
|
||||
// AVX512 support nr = 8 for unit inner strides for result matrix.
|
||||
nr = IsReal && Vectorizable && UnitResIncr ? 8 : 4,
|
||||
#else
|
||||
nr = 4,
|
||||
#endif
|
||||
|
||||
// register block size along the M direction (currently, this one cannot be modified)
|
||||
default_mr = (plain_enum_min(16, NumberOfRegisters)/2/nr)*LhsPacketSize,
|
||||
@ -545,8 +555,9 @@ public:
|
||||
|
||||
};
|
||||
|
||||
template<typename RealScalar, bool ConjLhs_, int Arch, int PacketSize_>
|
||||
class gebp_traits<std::complex<RealScalar>, RealScalar, ConjLhs_, false, Arch, PacketSize_>
|
||||
|
||||
template<typename RealScalar, bool UnitResIncr_, bool ConjLhs_, int Arch, int PacketSize_>
|
||||
class gebp_traits<std::complex<RealScalar>, RealScalar, UnitResIncr_, ConjLhs_, false, Arch, PacketSize_>
|
||||
{
|
||||
public:
|
||||
typedef std::complex<RealScalar> LhsScalar;
|
||||
@ -756,8 +767,8 @@ template<typename Packet> struct unpacket_traits<DoublePacket<Packet> > {
|
||||
// return res;
|
||||
// }
|
||||
|
||||
template<typename RealScalar, bool ConjLhs_, bool ConjRhs_, int Arch, int PacketSize_>
|
||||
class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, ConjLhs_, ConjRhs_, Arch, PacketSize_ >
|
||||
template<typename RealScalar, bool UnitResIncr_, bool ConjLhs_, bool ConjRhs_, int Arch, int PacketSize_>
|
||||
class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, UnitResIncr_, ConjLhs_, ConjRhs_, Arch, PacketSize_ >
|
||||
{
|
||||
public:
|
||||
typedef std::complex<RealScalar> Scalar;
|
||||
@ -922,8 +933,8 @@ protected:
|
||||
conj_helper<LhsScalar,RhsScalar,ConjLhs,ConjRhs> cj;
|
||||
};
|
||||
|
||||
template<typename RealScalar, bool ConjRhs_, int Arch, int PacketSize_>
|
||||
class gebp_traits<RealScalar, std::complex<RealScalar>, false, ConjRhs_, Arch, PacketSize_ >
|
||||
template<typename RealScalar, bool UnitResIncr, bool ConjRhs_, int Arch, int PacketSize_>
|
||||
class gebp_traits<RealScalar, std::complex<RealScalar>, UnitResIncr, false, ConjRhs_, Arch, PacketSize_ >
|
||||
{
|
||||
public:
|
||||
typedef std::complex<RealScalar> Scalar;
|
||||
@ -1058,9 +1069,9 @@ protected:
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
||||
struct gebp_kernel
|
||||
{
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketHalf> HalfTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketQuarter> QuarterTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketHalf> HalfTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketQuarter> QuarterTraits;
|
||||
|
||||
typedef typename Traits::ResScalar ResScalar;
|
||||
typedef typename Traits::LhsPacket LhsPacket;
|
||||
@ -1071,7 +1082,7 @@ struct gebp_kernel
|
||||
|
||||
typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 15>::type RhsPanel15;
|
||||
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,DataMapper::incr == 1,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
|
||||
|
||||
typedef typename SwappedTraits::ResScalar SResScalar;
|
||||
typedef typename SwappedTraits::LhsPacket SLhsPacket;
|
||||
@ -1109,11 +1120,11 @@ struct gebp_kernel
|
||||
};
|
||||
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs,
|
||||
int SwappedLhsProgress = gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target>::LhsProgress>
|
||||
int SwappedLhsProgress = gebp_traits<RhsScalar,LhsScalar,DataMapper::incr == 1,ConjugateRhs,ConjugateLhs,Architecture::Target>::LhsProgress>
|
||||
struct last_row_process_16_packets
|
||||
{
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,DataMapper::incr == 1,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
|
||||
|
||||
typedef typename Traits::ResScalar ResScalar;
|
||||
typedef typename SwappedTraits::LhsPacket SLhsPacket;
|
||||
@ -1141,8 +1152,8 @@ struct last_row_process_16_packets
|
||||
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
|
||||
struct last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs, 16> {
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,DataMapper::incr == 1,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
|
||||
|
||||
typedef typename Traits::ResScalar ResScalar;
|
||||
typedef typename SwappedTraits::LhsPacket SLhsPacket;
|
||||
@ -1408,6 +1419,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
||||
Index rows, Index depth, Index cols, ResScalar alpha,
|
||||
Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||
{
|
||||
#if defined(EIGEN_VECTORIZE_AVX512)
|
||||
if (nr == 8) {
|
||||
bool done = gemm_kernel(
|
||||
rows, cols, depth, alpha, blockA, blockB,
|
||||
(ResScalar *)res.data(), res.stride(),
|
||||
strideA, strideB, offsetA, offsetB);
|
||||
if (done) return;
|
||||
}
|
||||
#endif
|
||||
Traits traits;
|
||||
SwappedTraits straits;
|
||||
|
||||
@ -2397,51 +2417,67 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, ColMajor, Co
|
||||
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
|
||||
Index count = 0;
|
||||
const Index peeled_k = (depth/PacketSize)*PacketSize;
|
||||
// if(nr>=8)
|
||||
// {
|
||||
// for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
// {
|
||||
// // skip what we have before
|
||||
// if(PanelMode) count += 8 * offset;
|
||||
// const Scalar* b0 = &rhs[(j2+0)*rhsStride];
|
||||
// const Scalar* b1 = &rhs[(j2+1)*rhsStride];
|
||||
// const Scalar* b2 = &rhs[(j2+2)*rhsStride];
|
||||
// const Scalar* b3 = &rhs[(j2+3)*rhsStride];
|
||||
// const Scalar* b4 = &rhs[(j2+4)*rhsStride];
|
||||
// const Scalar* b5 = &rhs[(j2+5)*rhsStride];
|
||||
// const Scalar* b6 = &rhs[(j2+6)*rhsStride];
|
||||
// const Scalar* b7 = &rhs[(j2+7)*rhsStride];
|
||||
// Index k=0;
|
||||
// if(PacketSize==8) // TODO enable vectorized transposition for PacketSize==4
|
||||
// {
|
||||
// for(; k<peeled_k; k+=PacketSize) {
|
||||
// PacketBlock<Packet> kernel;
|
||||
// for (int p = 0; p < PacketSize; ++p) {
|
||||
// kernel.packet[p] = ploadu<Packet>(&rhs[(j2+p)*rhsStride+k]);
|
||||
// }
|
||||
// ptranspose(kernel);
|
||||
// for (int p = 0; p < PacketSize; ++p) {
|
||||
// pstoreu(blockB+count, cj.pconj(kernel.packet[p]));
|
||||
// count+=PacketSize;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// for(; k<depth; k++)
|
||||
// {
|
||||
// blockB[count+0] = cj(b0[k]);
|
||||
// blockB[count+1] = cj(b1[k]);
|
||||
// blockB[count+2] = cj(b2[k]);
|
||||
// blockB[count+3] = cj(b3[k]);
|
||||
// blockB[count+4] = cj(b4[k]);
|
||||
// blockB[count+5] = cj(b5[k]);
|
||||
// blockB[count+6] = cj(b6[k]);
|
||||
// blockB[count+7] = cj(b7[k]);
|
||||
// count += 8;
|
||||
// }
|
||||
// // skip what we have after
|
||||
// if(PanelMode) count += 8 * (stride-offset-depth);
|
||||
// }
|
||||
// }
|
||||
if(nr>=8)
|
||||
{
|
||||
for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
{
|
||||
// skip what we have before
|
||||
if(PanelMode) count += 8 * offset;
|
||||
const LinearMapper dm0 = rhs.getLinearMapper(0, j2+0);
|
||||
const LinearMapper dm1 = rhs.getLinearMapper(0, j2+1);
|
||||
const LinearMapper dm2 = rhs.getLinearMapper(0, j2+2);
|
||||
const LinearMapper dm3 = rhs.getLinearMapper(0, j2+3);
|
||||
const LinearMapper dm4 = rhs.getLinearMapper(0, j2+4);
|
||||
const LinearMapper dm5 = rhs.getLinearMapper(0, j2+5);
|
||||
const LinearMapper dm6 = rhs.getLinearMapper(0, j2+6);
|
||||
const LinearMapper dm7 = rhs.getLinearMapper(0, j2+7);
|
||||
Index k=0;
|
||||
#if 0
|
||||
// TODO Need to enable vectorized transposition.
|
||||
if((PacketSize%8)==0) // TODO enable vectorized transposition for PacketSize==4
|
||||
{
|
||||
for(; k<peeled_k; k+=PacketSize) {
|
||||
PacketBlock<Packet,(PacketSize%8)==0?8:PacketSize> kernel;
|
||||
|
||||
kernel.packet[0] = dm0.template loadPacket<Packet>(k);
|
||||
kernel.packet[1] = dm1.template loadPacket<Packet>(k);
|
||||
kernel.packet[2] = dm2.template loadPacket<Packet>(k);
|
||||
kernel.packet[3] = dm3.template loadPacket<Packet>(k);
|
||||
kernel.packet[4] = dm4.template loadPacket<Packet>(k);
|
||||
kernel.packet[5] = dm5.template loadPacket<Packet>(k);
|
||||
kernel.packet[6] = dm6.template loadPacket<Packet>(k);
|
||||
kernel.packet[7] = dm7.template loadPacket<Packet>(k);
|
||||
|
||||
ptranspose(kernel);
|
||||
|
||||
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0]));
|
||||
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize]));
|
||||
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize]));
|
||||
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize]));
|
||||
pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel.packet[4%PacketSize]));
|
||||
pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel.packet[5%PacketSize]));
|
||||
pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel.packet[6%PacketSize]));
|
||||
pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel.packet[7%PacketSize]));
|
||||
count+=8*PacketSize;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
for(; k<depth; k++)
|
||||
{
|
||||
blockB[count+0] = cj(dm0(k));
|
||||
blockB[count+1] = cj(dm1(k));
|
||||
blockB[count+2] = cj(dm2(k));
|
||||
blockB[count+3] = cj(dm3(k));
|
||||
blockB[count+4] = cj(dm4(k));
|
||||
blockB[count+5] = cj(dm5(k));
|
||||
blockB[count+6] = cj(dm6(k));
|
||||
blockB[count+7] = cj(dm7(k));
|
||||
count += 8;
|
||||
}
|
||||
// skip what we have after
|
||||
if(PanelMode) count += 8 * (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
|
||||
if(nr>=4)
|
||||
{
|
||||
@ -2522,39 +2558,50 @@ struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMo
|
||||
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
|
||||
Index count = 0;
|
||||
|
||||
// if(nr>=8)
|
||||
// {
|
||||
// for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
// {
|
||||
// // skip what we have before
|
||||
// if(PanelMode) count += 8 * offset;
|
||||
// for(Index k=0; k<depth; k++)
|
||||
// {
|
||||
// if (PacketSize==8) {
|
||||
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
|
||||
// pstoreu(blockB+count, cj.pconj(A));
|
||||
// } else if (PacketSize==4) {
|
||||
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
|
||||
// Packet B = ploadu<Packet>(&rhs[k*rhsStride + j2 + PacketSize]);
|
||||
// pstoreu(blockB+count, cj.pconj(A));
|
||||
// pstoreu(blockB+count+PacketSize, cj.pconj(B));
|
||||
// } else {
|
||||
// const Scalar* b0 = &rhs[k*rhsStride + j2];
|
||||
// blockB[count+0] = cj(b0[0]);
|
||||
// blockB[count+1] = cj(b0[1]);
|
||||
// blockB[count+2] = cj(b0[2]);
|
||||
// blockB[count+3] = cj(b0[3]);
|
||||
// blockB[count+4] = cj(b0[4]);
|
||||
// blockB[count+5] = cj(b0[5]);
|
||||
// blockB[count+6] = cj(b0[6]);
|
||||
// blockB[count+7] = cj(b0[7]);
|
||||
// }
|
||||
// count += 8;
|
||||
// }
|
||||
// // skip what we have after
|
||||
// if(PanelMode) count += 8 * (stride-offset-depth);
|
||||
// }
|
||||
// }
|
||||
if(nr>=8)
|
||||
{
|
||||
for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
{
|
||||
// skip what we have before
|
||||
if(PanelMode) count += 8 * offset;
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
if (PacketSize==8) {
|
||||
// Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
|
||||
Packet A = rhs.template loadPacket<Packet>(k, j2);
|
||||
pstoreu(blockB+count, cj.pconj(A));
|
||||
} else if (HasHalf && HalfPacketSize==8) {
|
||||
HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
|
||||
pstoreu(blockB+count, cj.pconj(A));
|
||||
} else if (HasQuarter && QuarterPacketSize==8) {
|
||||
QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
|
||||
pstoreu(blockB+count, cj.pconj(A));
|
||||
} else if (PacketSize==4) {
|
||||
// Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
|
||||
// Packet B = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2 + PacketSize]);
|
||||
Packet A = rhs.template loadPacket<Packet>(k, j2);
|
||||
Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
|
||||
pstoreu(blockB+count, cj.pconj(A));
|
||||
pstoreu(blockB+count+PacketSize, cj.pconj(B));
|
||||
} else {
|
||||
// const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2];
|
||||
const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
|
||||
blockB[count+0] = cj(dm0(0));
|
||||
blockB[count+1] = cj(dm0(1));
|
||||
blockB[count+2] = cj(dm0(2));
|
||||
blockB[count+3] = cj(dm0(3));
|
||||
blockB[count+4] = cj(dm0(4));
|
||||
blockB[count+5] = cj(dm0(5));
|
||||
blockB[count+6] = cj(dm0(6));
|
||||
blockB[count+7] = cj(dm0(7));
|
||||
}
|
||||
count += 8;
|
||||
}
|
||||
// skip what we have after
|
||||
if(PanelMode) count += 8 * (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
|
||||
if(nr>=4)
|
||||
{
|
||||
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
|
||||
|
@ -26,7 +26,7 @@ template<
|
||||
int ResInnerStride>
|
||||
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride>
|
||||
{
|
||||
typedef gebp_traits<RhsScalar,LhsScalar> Traits;
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,ResInnerStride == 1> Traits;
|
||||
|
||||
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
static EIGEN_STRONG_INLINE void run(
|
||||
@ -57,7 +57,7 @@ template<
|
||||
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride>
|
||||
{
|
||||
|
||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar, ResInnerStride == 1> Traits;
|
||||
|
||||
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
static void run(Index rows, Index cols, Index depth,
|
||||
@ -287,7 +287,6 @@ class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, M
|
||||
};
|
||||
typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
|
||||
typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||
enum {
|
||||
SizeA = ActualRows * MaxDepth,
|
||||
SizeB = ActualCols * MaxDepth
|
||||
@ -336,7 +335,6 @@ class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, M
|
||||
};
|
||||
typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
|
||||
typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||
|
||||
Index m_sizeA;
|
||||
Index m_sizeB;
|
||||
|
@ -67,7 +67,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
|
||||
ResScalar* _res, Index resIncr, Index resStride,
|
||||
const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
|
||||
{
|
||||
typedef gebp_traits<LhsScalar,RhsScalar> Traits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ResInnerStride == 1> Traits;
|
||||
|
||||
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
|
||||
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
|
||||
@ -140,7 +140,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
|
||||
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo>
|
||||
struct tribb_kernel
|
||||
{
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ResInnerStride == 1,ConjLhs,ConjRhs> Traits;
|
||||
typedef typename Traits::ResScalar ResScalar;
|
||||
|
||||
enum {
|
||||
|
@ -55,7 +55,7 @@ template< \
|
||||
int RhsStorageOrder, bool ConjugateRhs> \
|
||||
struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1> \
|
||||
{ \
|
||||
typedef gebp_traits<EIGTYPE,EIGTYPE> Traits; \
|
||||
typedef gebp_traits<EIGTYPE,EIGTYPE,true> Traits; \
|
||||
\
|
||||
static void run(Index rows, Index cols, Index depth, \
|
||||
const EIGTYPE* _lhs, Index lhsStride, \
|
||||
|
@ -351,7 +351,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,t
|
||||
{
|
||||
Index size = rows;
|
||||
|
||||
typedef gebp_traits<Scalar,Scalar> Traits;
|
||||
typedef gebp_traits<Scalar,Scalar,ResInnerStride == 1> Traits;
|
||||
|
||||
typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
|
||||
typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor> LhsTransposeMapper;
|
||||
@ -446,7 +446,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,f
|
||||
{
|
||||
Index size = cols;
|
||||
|
||||
typedef gebp_traits<Scalar,Scalar> Traits;
|
||||
typedef gebp_traits<Scalar,Scalar,ResInnerStride == 1> Traits;
|
||||
|
||||
typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
|
||||
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
|
||||
|
@ -89,7 +89,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
||||
RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
|
||||
{
|
||||
|
||||
typedef gebp_traits<Scalar,Scalar> Traits;
|
||||
typedef gebp_traits<Scalar,Scalar,ResInnerStride == 1> Traits;
|
||||
enum {
|
||||
SmallPanelWidth = 2 * plain_enum_max(Traits::mr, Traits::nr),
|
||||
IsLower = (Mode&Lower) == Lower,
|
||||
@ -247,7 +247,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
|
||||
LhsStorageOrder,ConjugateLhs,
|
||||
RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
|
||||
{
|
||||
typedef gebp_traits<Scalar,Scalar> Traits;
|
||||
typedef gebp_traits<Scalar,Scalar,ResInnerStride == 1> Traits;
|
||||
enum {
|
||||
SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr),
|
||||
IsLower = (Mode&Lower) == Lower,
|
||||
|
@ -189,7 +189,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
|
||||
TriMapper tri(_tri, triStride);
|
||||
OtherMapper other(_other, otherStride, otherIncr);
|
||||
|
||||
typedef gebp_traits<Scalar,Scalar> Traits;
|
||||
typedef gebp_traits<Scalar,Scalar,OtherInnerStride == 1> Traits;
|
||||
|
||||
enum {
|
||||
SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr),
|
||||
@ -336,7 +336,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
|
||||
LhsMapper lhs(_other, otherStride, otherIncr);
|
||||
RhsMapper rhs(_tri, triStride);
|
||||
|
||||
typedef gebp_traits<Scalar,Scalar> Traits;
|
||||
typedef gebp_traits<Scalar,Scalar,OtherInnerStride == 1> Traits;
|
||||
enum {
|
||||
RhsStorageOrder = TriStorageOrder,
|
||||
SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr),
|
||||
|
@ -173,6 +173,7 @@ class blas_data_mapper<Scalar,Index,StorageOrder,AlignmentType,1>
|
||||
public:
|
||||
typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
|
||||
typedef BlasVectorMapper<Scalar, Index> VectorMapper;
|
||||
static constexpr int incr = 1;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1)
|
||||
: m_data(data), m_stride(stride)
|
||||
@ -285,6 +286,7 @@ class blas_data_mapper
|
||||
{
|
||||
public:
|
||||
typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;
|
||||
static constexpr int incr = Incr;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {}
|
||||
|
||||
@ -402,6 +404,9 @@ public:
|
||||
storePacketBlock_helper<SubPacket, Scalar, n, n-1> spb;
|
||||
spb.store(this, i,j,block);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
|
||||
EIGEN_DEVICE_FUNC Scalar* data() const { return m_data; }
|
||||
protected:
|
||||
Scalar* EIGEN_RESTRICT m_data;
|
||||
const Index m_stride;
|
||||
|
@ -143,8 +143,8 @@ int main()
|
||||
|
||||
// Specialize GEBP kernel and traits for mpreal (no need for peeling, nor complicated stuff)
|
||||
// This also permits to directly call mpfr's routines and avoid many temporaries produced by mpreal
|
||||
template<>
|
||||
class gebp_traits<mpfr::mpreal, mpfr::mpreal, false, false>
|
||||
template<bool UnitResIncr>
|
||||
class gebp_traits<mpfr::mpreal, mpfr::mpreal, UnitResIncr, false, false>
|
||||
{
|
||||
public:
|
||||
typedef mpfr::mpreal ResScalar;
|
||||
|
Loading…
x
Reference in New Issue
Block a user