Add AVX512 optimizations for matrix multiply

This commit is contained in:
aaraujom 2022-05-12 23:41:19 +00:00 committed by Rasmus Munk Larsen
parent 00b75375e7
commit 25db0b4a82
17 changed files with 1251 additions and 142 deletions

View File

@ -191,6 +191,7 @@ using std::ptrdiff_t;
#include "src/Core/arch/AVX/MathFunctions.h" #include "src/Core/arch/AVX/MathFunctions.h"
#include "src/Core/arch/AVX512/MathFunctions.h" #include "src/Core/arch/AVX512/MathFunctions.h"
#include "src/Core/arch/AVX512/TrsmKernel.h" #include "src/Core/arch/AVX512/TrsmKernel.h"
#include "src/Core/arch/AVX512/GemmKernel.h"
#elif defined EIGEN_VECTORIZE_AVX #elif defined EIGEN_VECTORIZE_AVX
// Use AVX for floats and doubles, SSE for integers // Use AVX for floats and doubles, SSE for integers
#include "src/Core/arch/SSE/PacketMath.h" #include "src/Core/arch/SSE/PacketMath.h"

View File

@ -0,0 +1,973 @@
#ifndef GEMM_KERNEL_H
#define GEMM_KERNEL_H
#include <x86intrin.h>
#include <immintrin.h>
#include <type_traits>
#define SECOND_FETCH (32)
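// SECOND_FETCH is the number of unrolled k-iterations held back from the main loop
// so that the C tile can be prefetched shortly before it is updated (see kloop()).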
#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
// Use fewer registers to load A elements to work around compiler spills. Loses a
// bit of performance (less than ~2%).
#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
#endif
namespace Eigen {
namespace internal {
static inline constexpr int div_up(int a, int b) {
return (a + b - 1) / b;
}
template <typename Scalar>
class gemm_class
{
using vec = typename std::conditional<std::is_same<Scalar, float>::value,
Packet16f, Packet8d>::type;
using vec_ymm = typename std::conditional<std::is_same<Scalar, float>::value,
Packet8f, Packet4d>::type;
using vec_xmm = typename std::conditional<std::is_same<Scalar, float>::value,
Packet4f, Packet2d>::type;
static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
static constexpr int a_regs[] = {0, 1, 2, 3, 4, 5};
#else
static constexpr int a_regs[] = {0, 1, 2, 0, 1, 2};
#endif
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
static constexpr int b_regs[] = {6, 7};
#else
static constexpr int b_regs[] = {6, 6};
#endif
static constexpr int c_regs[] = {
8 , 16, 24,
9 , 17, 25,
10, 18, 26,
11, 19, 27,
12, 20, 28,
13, 21, 29,
14, 22, 30,
15, 23, 31,
};
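// Register allocation: zmm0-zmm5 hold packed A values, zmm6-zmm7 hold broadcast B
// values, and zmm8-zmm31 are the 24 accumulators, addressed as c_regs[i + idx * 3]
// for up to 3 row blocks (48 floats or 24 doubles) by 8 columns of C.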
static constexpr int a_shift = 128;
static constexpr int b_shift = 128;
static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
static constexpr int b_prefetch_size = nelems_in_cache_line * 8;
vec zmm[32];
// gemm arguments.
int64_t m;
const int64_t n, k, ldc;
const Scalar *alpha;
const Scalar *a, *b;
Scalar *c;
const bool is_alpha1;
const bool is_beta0;
const int64_t a_stride, b_stride;
const int64_t a_off, b_off;
public:
EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr)
{
_mm_prefetch((char *) (a_prefetch_size + a_addr - a_shift), _MM_HINT_T0);
}
EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr)
{
_mm_prefetch((char *) (b_prefetch_size + b_addr - b_shift), _MM_HINT_T0);
}
EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr)
{
_mm_prefetch((char *) (x_addr - a_shift), _MM_HINT_T2);
}
EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr)
{
#if defined(__PRFCHW__) && __PRFCHW__ == 1
_m_prefetchw((void *) c_addr);
#else
_mm_prefetch((char *) c_addr, _MM_HINT_T0);
#endif
}
template <int nelems>
EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr)
{
switch (nelems * sizeof(*a_addr) * 8) {
default:
case 512 * 3: a_reg = ploadu<vec>(a_addr); break;
case 512 * 2: a_reg = ploadu<vec>(a_addr); break;
case 512 * 1: a_reg = ploadu<vec>(a_addr); break;
case 256 * 1: a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr)))); break;
case 128 * 1: a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr)))); break;
case 64 * 1: a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr))); break;
case 32 * 1: a_reg = pload1<vec>(a_addr); break;
}
}
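// The switch key above is the tail width in bits (nelems * sizeof(Scalar) * 8):
// multiples of 512 bits use full zmm loads, 256/128-bit tails broadcast a ymm/xmm
// load into a zmm register, and 64/32-bit tails broadcast a single 64/32-bit element.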
EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr)
{
b_reg = pload1<vec>(b_addr);
}
template <int nelems>
EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src)
{
switch (nelems * sizeof(*mem) * 8) {
default:
case 512 * 3: pstoreu(mem, src); break;
case 512 * 2: pstoreu(mem, src); break;
case 512 * 1: pstoreu(mem, src); break;
case 256 * 1: pstoreu(mem, preinterpret<vec_ymm>(src)); break;
case 128 * 1: pstoreu(mem, preinterpret<vec_xmm>(src)); break;
case 64 * 1: pstorel(mem, preinterpret<vec_xmm>(src)); break;
case 32 * 1: pstores(mem, preinterpret<vec_xmm>(src)); break;
}
}
template <int nelems>
EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src)
{
switch (nelems * sizeof(*mem) * 8) {
default:
case 512 * 3: dst = padd(src, ploadu<vec>(mem)); break;
case 512 * 2: dst = padd(src, ploadu<vec>(mem)); break;
case 512 * 1: dst = padd(src, ploadu<vec>(mem)); break;
case 256 * 1: dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem))); break;
case 128 * 1: dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem))); break;
case 64 * 1: dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem))); break;
case 32 * 1: dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem))); break;
}
}
EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
dst = pmadd(src1, src2, dst);
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
// Work around register spills for gcc and clang
__asm__ ("#" : [dst] "+v" (dst) : [src1] "%v" (src1), [src2] "v" (src2));
#endif
}
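// The empty inline asm above emits no instructions; it only ties dst/src1/src2 to
// vector registers so gcc/clang keep the accumulator in a register instead of spilling.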
template <int nelems>
EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale)
{
switch (nelems * sizeof(*mem) * 8) {
default:
case 512 * 3: dst = pmadd(scale, src, ploadu<vec>(mem)); break;
case 512 * 2: dst = pmadd(scale, src, ploadu<vec>(mem)); break;
case 512 * 1: dst = pmadd(scale, src, ploadu<vec>(mem)); break;
case 256 * 1: dst = preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem))); break;
case 128 * 1: dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem))); break;
case 64 * 1: dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem))); break;
case 32 * 1: dst = preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem))); break;
}
}
gemm_class(int64_t m_, int64_t n_, int64_t k_, int64_t ldc_, const Scalar *alpha_,
const Scalar *a_, const Scalar *b_, Scalar *c_,
bool is_alpha1_, bool is_beta0_,
int64_t a_stride_, int64_t b_stride_,
int64_t a_off_, int64_t b_off_)
: m(m_)
, n(n_)
, k(k_)
, ldc(ldc_)
, alpha(alpha_)
, a(a_)
, b(b_)
, c(c_)
, is_alpha1(is_alpha1_)
, is_beta0(is_beta0_)
, a_stride(a_stride_)
, b_stride(b_stride_)
, a_off(a_off_)
, b_off(b_off_)
{
// Zero out all accumulation registers.
zmm[8 ] = pzero(zmm[8 ]);
zmm[9 ] = pzero(zmm[9 ]);
zmm[10] = pzero(zmm[10]);
zmm[11] = pzero(zmm[11]);
zmm[12] = pzero(zmm[12]);
zmm[13] = pzero(zmm[13]);
zmm[14] = pzero(zmm[14]);
zmm[15] = pzero(zmm[15]);
zmm[16] = pzero(zmm[16]);
zmm[17] = pzero(zmm[17]);
zmm[18] = pzero(zmm[18]);
zmm[19] = pzero(zmm[19]);
zmm[20] = pzero(zmm[20]);
zmm[21] = pzero(zmm[21]);
zmm[22] = pzero(zmm[22]);
zmm[23] = pzero(zmm[23]);
zmm[24] = pzero(zmm[24]);
zmm[25] = pzero(zmm[25]);
zmm[26] = pzero(zmm[26]);
zmm[27] = pzero(zmm[27]);
zmm[28] = pzero(zmm[28]);
zmm[29] = pzero(zmm[29]);
zmm[30] = pzero(zmm[30]);
zmm[31] = pzero(zmm[31]);
}
template <int j, int endX, int i, int endY, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)>
a_loads(const Scalar *ao)
{
EIGEN_UNUSED_VARIABLE(ao);
}
template <int j, int endX, int i, int endY, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)>
a_loads(const Scalar *ao)
{
if (j < endX) {
if (i < endY) {
auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
a_load<nelems>(a_reg, a_addr);
a_loads<j, endX, i + 1, endY, nelems>(ao);
} else {
a_loads<j + 1, endX, 0, endY, nelems>(ao);
}
}
}
template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)>
prefetch_cs(const Scalar *co1, const Scalar *co2)
{
EIGEN_UNUSED_VARIABLE(co1);
EIGEN_UNUSED_VARIABLE(co2);
}
/* C prefetch loop structure.
* for (int un = 0; un < 8; un++) {
* if (b_unroll >= un + 1) {
* if (un == 4) co2 = co1 + 4 * ldc;
*
* for (int i = 0; i < um_vecs; i++) {
* Scalar *co = (un + 1 <= 4) ? co1 : co2;
* auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
* prefetch_c(co + co_off);
* }
* }
* }
*/
template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)>
prefetch_cs(Scalar *&co1, Scalar *&co2)
{
if (un < max_b_unroll) {
if (b_unroll >= un + 1) {
if (un == 4 && i == 0) co2 = co1 + 4 * ldc;
if (i < um_vecs) {
Scalar *co = (un + 1 <= 4) ? co1 : co2;
auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
prefetch_c(co + co_off);
prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
} else {
prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
}
} else {
prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
}
}
}
// load_c
template <int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)>
scale_load_c(const Scalar *cox, vec &alpha_reg)
{
EIGEN_UNUSED_VARIABLE(cox);
EIGEN_UNUSED_VARIABLE(alpha_reg);
}
template <int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)>
scale_load_c(const Scalar *cox, vec &alpha_reg)
{
if (i < um_vecs) {
auto &c_reg = zmm[c_regs[i + idx * 3]];
auto c_mem = cox + i * nelems_in_cache_line;
if (!is_beta0 && is_alpha1)
vaddm<nelems>(c_reg, c_mem, c_reg);
else if (!is_beta0 && !is_alpha1)
vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg);
else if (is_beta0 && !is_alpha1)
c_reg = pmul(alpha_reg, c_reg);
scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
}
}
// store_c
template <int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)>
write_c(Scalar *cox)
{
EIGEN_UNUSED_VARIABLE(cox);
}
template <int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)>
write_c(Scalar *cox)
{
if (i < um_vecs) {
auto &c_reg = zmm[c_regs[i + idx * 3]];
auto c_mem = cox + i * nelems_in_cache_line;
c_store<nelems>(c_mem, c_reg);
c_reg = pzero(c_reg);
write_c<i + 1, um_vecs, idx, nelems>(cox);
}
}
// update c matrix
template <int pow, int max_b_unroll, int count, int a_unroll, int b_unroll, int idx>
EIGEN_ALWAYS_INLINE std::enable_if_t<(pow > (max_b_unroll << 1)) || (count > (pow + 1) / 2 + 1)>
c_update(Scalar *&co1, Scalar *&co2)
{
EIGEN_UNUSED_VARIABLE(co1);
EIGEN_UNUSED_VARIABLE(co2);
}
/* C update loop structure.
* co2 = co1 + ldc;
*
* auto &alpha_reg = zmm[0];
* if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
*
* int idx = 0;
* for (pow = 1; pow <= 8; pow <<= 1) {
*
* if (b_unroll >= pow) {
* for (count = 1; count < (pow + 1) / 2 + 1; count++) {
* if (pow >= 4) co2 += ldc;
*
* const Scalar *cox = (idx == 0) ? co1 : co2;
*
* const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
* scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
* write_c<0, um_vecs, idx, a_unroll>(cox);
*
* idx++;
* }
* }
* }
*
* if (b_unroll == 1)
* co1 += ldc;
* else
* co1 = co2 + ldc;
*/
template <int pow, int max_b_unroll, int count, int a_unroll, int b_unroll, int idx>
EIGEN_ALWAYS_INLINE std::enable_if_t<(pow <= (max_b_unroll << 1)) && (count <= (pow + 1) / 2 + 1)>
c_update(Scalar *&co1, Scalar *&co2)
{
const bool first_call = idx == 0;
auto &alpha_reg = zmm[0];
if (first_call) {
co2 = co1 + ldc;
if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
}
if (pow < (max_b_unroll << 1) && pow <= b_unroll) {
if (count < (pow + 1) / 2 + 1) {
if (pow >= 4) co2 += ldc;
Scalar *cox = idx == 0 ? co1 : co2;
const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
write_c<0, um_vecs, idx, a_unroll>(cox);
// Go to the next count and next idx.
c_update<pow, max_b_unroll, count + 1, a_unroll, b_unroll, idx + 1>(co1, co2);
} else {
// Go to the next pow and reset count.
c_update<pow << 1, max_b_unroll, 1, a_unroll, b_unroll, idx>(co1, co2);
}
} else {
if (b_unroll == 1)
co1 += ldc;
else
co1 = co2 + ldc;
}
}
// compute
template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)>
compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg)
{
EIGEN_UNUSED_VARIABLE(ao);
EIGEN_UNUSED_VARIABLE(bo);
EIGEN_UNUSED_VARIABLE(fetchA_idx);
EIGEN_UNUSED_VARIABLE(fetchB_idx);
EIGEN_UNUSED_VARIABLE(b_reg);
}
template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)>
compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg)
{
if (um < um_vecs) {
auto &c_reg = zmm[c_regs[um + idx * 3]];
auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
vfmadd(c_reg, a_reg, b_reg);
if (!fetch_x && um == 0 && (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) || (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
fetchA_idx++;
}
if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
fetchB_idx++;
}
compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
}
}
// load_a
template <int um, int um_vecs, int uk, int nelems, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)>
load_a(const Scalar *ao)
{
EIGEN_UNUSED_VARIABLE(ao);
}
template <int um, int um_vecs, int uk, int nelems, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)>
load_a(const Scalar *ao)
{
if (um < um_vecs) {
auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
const Scalar *a_addr = ao + nelems * (1 + !ktail + uk) + nelems_in_cache_line * um - a_shift;
#else
const Scalar *a_addr = ao + nelems * (1 + uk) + nelems_in_cache_line * um - a_shift;
#endif
a_load<nelems>(a_reg, a_addr);
load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
}
}
template<int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)>
innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx)
{
EIGEN_UNUSED_VARIABLE(aa);
EIGEN_UNUSED_VARIABLE(ao);
EIGEN_UNUSED_VARIABLE(bo);
EIGEN_UNUSED_VARIABLE(co2);
EIGEN_UNUSED_VARIABLE(fetchA_idx);
EIGEN_UNUSED_VARIABLE(fetchB_idx);
}
template<int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)>
innerkernel_1pow(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx)
{
const int idx = (pow / 2) + count;
if (count < (pow + 1) / 2) {
auto &b_reg = zmm[b_regs[idx % 2]];
if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
if (fetch_x && uk == 3 && idx == 4) aa += 8;
if (b_unroll >= pow) {
compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift;
#else
const Scalar *b_addr = bo + b_unroll * uk + idx + 1 - b_shift;
#endif
b_load(b_reg, b_addr);
}
// Go to the next count.
innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
} else {
// Maybe prefetch C data after count-loop.
if (pow == 2 && c_fetch) {
if (uk % 3 == 0 && uk > 0) {
co2 += ldc;
} else {
prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
}
}
}
}
template<int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar * const &ao, const Scalar * const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx)
{
const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
if (max_b_unroll >= 1) innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
if (max_b_unroll >= 2) innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
if (max_b_unroll >= 4) innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
if (max_b_unroll >= 8) innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
// Load A after pow-loop.
load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
}
/* Inner kernel loop structure.
* for (int uk = 0; uk < kfactor; uk++) {
* int idx = 0;
*
* for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) {
* for (int count = 0; count < (pow + 1) / 2; count++) {
* auto &b_reg = zmm[b_regs[idx % 2]];
*
* if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
* if (fetch_x && uk == 3 && idx == 4) aa += 8;
*
* if (b_unroll >= pow) {
* compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
*
* const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ;
* b_load(b_reg, b_addr);
* }
* idx++;
* }
*
* Maybe prefetch C data.
* if (pow == 2 && c_fetch) {
* if (uk % 3 == 0 && uk > 0) {
* co2 += ldc;
* } else {
* prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
* }
* }
* }
*
* Load A.
* load_a<0, um_vecs, uk, ktail, a_unroll>(ao);
* }
*
* Advance A/B pointers after uk-loop.
* ao += a_unroll * kfactor;
* bo += b_unroll * kfactor;
*/
template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch>
EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2)
{
int fetchA_idx = 0;
int fetchB_idx = 0;
const bool fetch_x = k_factor == max_k_factor;
const bool ktail = k_factor == 1;
static_assert(k_factor <= 4 && k_factor > 0,
"innerkernel maximum k_factor supported is 4");
if (k_factor > 0) innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
if (k_factor > 1) innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
if (k_factor > 2) innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
if (k_factor > 3) innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
// Advance A/B pointers after uk-loop.
ao += a_unroll * k_factor;
bo += b_unroll * k_factor;
}
template <int a_unroll, int b_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2)
{
const int um_vecs = div_up(a_unroll, nelems_in_cache_line);
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
#else
a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
#endif
b_load(zmm[b_regs[0]], bo - b_shift + 0);
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
b_load(zmm[b_regs[1]], bo - b_shift + 1);
#endif
#ifndef SECOND_FETCH
prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
#endif // SECOND_FETCH
// Unrolling k-loop by a factor of 4.
const int max_k_factor = 4;
int64_t loop_count = k / max_k_factor;
if (loop_count > 0) {
#ifdef SECOND_FETCH
loop_count -= SECOND_FETCH;
#endif
while (loop_count > 0) {
innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
loop_count--;
}
#ifdef SECOND_FETCH
co2 = co1 + nelems_in_cache_line - 1;
loop_count += b_unroll;
while (loop_count > 0) {
innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
loop_count--;
}
loop_count += SECOND_FETCH - b_unroll;
while (loop_count > 0) {
innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
loop_count--;
}
#endif
}
// k-loop remainder handling.
loop_count = k % max_k_factor;
while (loop_count > 0) {
innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
loop_count--;
}
// Update C matrix.
c_update<1, max_b_unroll, 1, a_unroll, b_unroll, 0>(co1, co2);
}
template <int a_unroll, int b_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2)
{
// Set A matrix pointer.
ao = a + a_off * a_unroll;
// Set B matrix pointer if needed.
bo += b_unroll * b_off;
kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
// Advance B matrix pointer if needed.
bo += b_unroll * (b_stride - k - b_off);
// Advance prefetch A pointer.
aa += 16;
}
template <int a_unroll, int max_a_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2)
{
// Set prefetch A pointers.
const Scalar *aa = a + a_unroll * a_stride;
// Set C matrix pointers.
co1 = c;
if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
c += a_unroll;
// Set B matrix pointer.
bo = b;
// Main n-loop.
for (int64_t i = n / max_b_unroll; i > 0; i--)
nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
// n-remainders.
if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
#if 0
if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
#else
// Copy kernels don't support tails of n = 2 for single/double precision.
// Loop over ones.
int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
while (n_rem > 0) {nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2); n_rem--;}
#endif
// Advance A matrix pointer.
a = ao + a_unroll * (a_stride - k - a_off);
}
// Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll.
template <int max_a_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void compute_kern()
{
a -= -a_shift;
b -= -b_shift;
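// a -= -a_shift is a += a_shift: the pointers are pre-biased here, and every
// a_load/b_load below subtracts a_shift/b_shift back out.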
const Scalar *ao = nullptr;
const Scalar *bo = nullptr;
Scalar *co1 = nullptr;
Scalar *co2 = nullptr;
// Main m-loop.
for (; m >= max_a_unroll; m -= max_a_unroll)
mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
// m-remainders.
if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
if (m & 8 && max_a_unroll > 8) mloop< 8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
if (m & 4 && max_a_unroll > 4) mloop< 4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
if (m & 2 && max_a_unroll > 2 && is_f64) mloop< 2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
if (m & 1 && max_a_unroll > 1 && is_f64) mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
// Copy kernels don't support tails of m = 2 for single precision.
// Loop over ones.
if (is_f32) {
int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
while (m_rem > 0) {mloop< 1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2); m_rem--;}
}
}
};
// Compute kernel with max unroll support of:
// Single precision:
// max_a_unroll: 48, 32, 16, 8, 4, 2, 1
// max_b_unroll: 8, 4, 2, 1
// Double precision:
// max_a_unroll: 24, 16, 8, 4, 2, 1
// max_b_unroll: 8, 4, 2, 1
template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0>
EIGEN_DONT_INLINE void gemm_kern_avx512(int64_t *p_m, int64_t *p_n, int64_t *p_k,
Scalar *alpha, const Scalar *a, const Scalar *b, Scalar *c,
int64_t ldc, int64_t a_stride = -1, int64_t b_stride = -1,
int64_t a_off = 0, int64_t b_off = 0)
{
if (a_stride == -1) a_stride = *p_k;
if (b_stride == -1) b_stride = *p_k;
gemm_class<Scalar> g(*p_m, *p_n, *p_k, ldc, alpha, a, b, c,
is_alpha1, is_beta0, a_stride, b_stride, a_off, b_off);
g.template compute_kern<max_a_unroll, max_b_unroll>();
}
template <typename a_t, typename b_t, typename c_t>
bool gemm_kernel(int64_t m, int64_t n, int64_t k, c_t alpha,
const a_t *a, const b_t *b, c_t *c, int64_t ldc,
int64_t a_stride = -1, int64_t b_stride = -1,
int64_t a_off = 0, int64_t b_off = 0)
{
EIGEN_UNUSED_VARIABLE(m);
EIGEN_UNUSED_VARIABLE(n);
EIGEN_UNUSED_VARIABLE(k);
EIGEN_UNUSED_VARIABLE(alpha);
EIGEN_UNUSED_VARIABLE(a);
EIGEN_UNUSED_VARIABLE(b);
EIGEN_UNUSED_VARIABLE(c);
EIGEN_UNUSED_VARIABLE(ldc);
EIGEN_UNUSED_VARIABLE(a_stride);
EIGEN_UNUSED_VARIABLE(b_stride);
EIGEN_UNUSED_VARIABLE(a_off);
EIGEN_UNUSED_VARIABLE(b_off);
return false;
}
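// This unspecialized overload reports that no AVX512 kernel is available; the nr == 8
// dispatch added to gebp_kernel::operator() later in this commit only returns early
// when gemm_kernel() returns true and otherwise falls through to the generic path.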
template <>
bool gemm_kernel(int64_t m, int64_t n, int64_t k, float alpha,
const float *a, const float *b, float *c, int64_t ldc,
int64_t a_stride, int64_t b_stride,
int64_t a_off, int64_t b_off)
{
if (alpha == 1.f)
gemm_kern_avx512<float, 48, 8, true, false>(&m, &n, &k, &alpha, a, b, c,
ldc, a_stride, b_stride, a_off, b_off);
else
gemm_kern_avx512<float, 48, 8, false, false>(&m, &n, &k, &alpha, a, b, c,
ldc, a_stride, b_stride, a_off, b_off);
return true;
}
template <>
bool gemm_kernel(int64_t m, int64_t n, int64_t k, double alpha,
const double *a, const double *b, double *c, int64_t ldc,
int64_t a_stride, int64_t b_stride,
int64_t a_off, int64_t b_off)
{
if (alpha == 1.)
gemm_kern_avx512<double, 24, 8, true, false>(&m, &n, &k, &alpha, a, b, c,
ldc, a_stride, b_stride, a_off, b_off);
else
gemm_kern_avx512<double, 24, 8, false, false>(&m, &n, &k, &alpha, a, b, c,
ldc, a_stride, b_stride, a_off, b_off);
return true;
}
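// Illustrative sketch (not part of this file) of how the entry point above is driven,
// assuming blockA/blockB already hold panels packed the way gebp_kernel packs them;
// the variable names are hypothetical:
//
//   double *blockA = ..., *blockB = ...;  // packed A (rows x depth) and B (depth x cols)
//   double *C = ...;                      // result matrix with leading dimension ldc
//   double alpha = 1.0;
//   // Dispatches to gemm_kern_avx512<double, 24, 8, true, false> since alpha == 1.
//   bool done = gemm_kernel(rows, cols, depth, alpha, blockA, blockB, C, ldc,
//                           strideA, strideB, offsetA, offsetB);
//   if (!done) { /* fall back to the generic gebp path */ }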
template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs;
template<typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>
{
typedef typename packet_traits<Scalar>::type Packet;
typedef typename DataMapper::LinearMapper LinearMapper;
enum { PacketSize = packet_traits<Scalar>::size };
EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};
template<typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>
::operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
constexpr int nr = 8;
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
EIGEN_UNUSED_VARIABLE(stride);
EIGEN_UNUSED_VARIABLE(offset);
eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
Index count = 0;
const Index peeled_k = (depth/PacketSize)*PacketSize;
if(nr>=8)
{
for(Index j2=0; j2<packet_cols8; j2+=8)
{
// skip what we have before
if(PanelMode) count += 8 * offset;
const LinearMapper dm0 = rhs.getLinearMapper(0, j2+0);
const LinearMapper dm1 = rhs.getLinearMapper(0, j2+1);
const LinearMapper dm2 = rhs.getLinearMapper(0, j2+2);
const LinearMapper dm3 = rhs.getLinearMapper(0, j2+3);
const LinearMapper dm4 = rhs.getLinearMapper(0, j2+4);
const LinearMapper dm5 = rhs.getLinearMapper(0, j2+5);
const LinearMapper dm6 = rhs.getLinearMapper(0, j2+6);
const LinearMapper dm7 = rhs.getLinearMapper(0, j2+7);
Index k=0;
if((PacketSize%8)==0) // TODO enable vectorized transposition for PacketSize==4
{
for(; k<peeled_k; k+=PacketSize) {
PacketBlock<Packet,(PacketSize%8)==0?8:PacketSize> kernel;
kernel.packet[0] = dm0.template loadPacket<Packet>(k);
kernel.packet[1] = dm1.template loadPacket<Packet>(k);
kernel.packet[2] = dm2.template loadPacket<Packet>(k);
kernel.packet[3] = dm3.template loadPacket<Packet>(k);
kernel.packet[4] = dm4.template loadPacket<Packet>(k);
kernel.packet[5] = dm5.template loadPacket<Packet>(k);
kernel.packet[6] = dm6.template loadPacket<Packet>(k);
kernel.packet[7] = dm7.template loadPacket<Packet>(k);
ptranspose(kernel);
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0]));
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize]));
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize]));
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize]));
pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel.packet[4%PacketSize]));
pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel.packet[5%PacketSize]));
pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel.packet[6%PacketSize]));
pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel.packet[7%PacketSize]));
count+=8*PacketSize;
}
}
for(; k<depth; k++)
{
blockB[count+0] = cj(dm0(k));
blockB[count+1] = cj(dm1(k));
blockB[count+2] = cj(dm2(k));
blockB[count+3] = cj(dm3(k));
blockB[count+4] = cj(dm4(k));
blockB[count+5] = cj(dm5(k));
blockB[count+6] = cj(dm6(k));
blockB[count+7] = cj(dm7(k));
count += 8;
}
// skip what we have after
if(PanelMode) count += 8 * (stride-offset-depth);
}
}
if(nr>=4)
{
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
{
// skip what we have before
if(PanelMode) count += 4 * offset;
const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
Index k=0;
if((PacketSize%4)==0) // TODO enable vectorized transposition for PacketSize==2 ??
{
for(; k<peeled_k; k+=PacketSize) {
PacketBlock<Packet,(PacketSize%4)==0?4:PacketSize> kernel;
kernel.packet[0 ] = dm0.template loadPacket<Packet>(k);
kernel.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
kernel.packet[2%PacketSize] = dm2.template loadPacket<Packet>(k);
kernel.packet[3%PacketSize] = dm3.template loadPacket<Packet>(k);
ptranspose(kernel);
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0]));
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize]));
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize]));
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize]));
count+=4*PacketSize;
}
}
for(; k<depth; k++)
{
blockB[count+0] = cj(dm0(k));
blockB[count+1] = cj(dm1(k));
blockB[count+2] = cj(dm2(k));
blockB[count+3] = cj(dm3(k));
count += 4;
}
// skip what we have after
if(PanelMode) count += 4 * (stride-offset-depth);
}
}
// copy the remaining columns one at a time (nr==1)
for(Index j2=packet_cols4; j2<cols; ++j2)
{
if(PanelMode) count += offset;
const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
for(Index k=0; k<depth; k++)
{
blockB[count] = cj(dm0(k));
count += 1;
}
if(PanelMode) count += (stride-offset-depth);
}
}
} // namespace internal
} // namespace Eigen
#endif // GEMM_KERNEL_H

View File

@ -1432,6 +1432,7 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 16>& kernel) {
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \
                         INPUT[2 * INDEX + STRIDE]);
template<bool for_trsm = false>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]);
@ -1451,6 +1452,26 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5),_mm512_castps_pd(T7)));
// Transpose for gemm is slightly different than trsm.
if (!for_trsm) {
T0 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0x44);
T1 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0xee);
T2 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0x44);
T3 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0xee);
T4 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0x44);
T5 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0xee);
T6 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0x44);
T7 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0xee);
kernel.packet[0] = _mm512_shuffle_f32x4(T0, T2, 0x88);
kernel.packet[2] = _mm512_shuffle_f32x4(T0, T2, 0xdd);
kernel.packet[1] = _mm512_shuffle_f32x4(T4, T6, 0x88);
kernel.packet[3] = _mm512_shuffle_f32x4(T4, T6, 0xdd);
kernel.packet[4] = _mm512_shuffle_f32x4(T1, T3, 0x88);
kernel.packet[6] = _mm512_shuffle_f32x4(T1, T3, 0xdd);
kernel.packet[5] = _mm512_shuffle_f32x4(T5, T7, 0x88);
kernel.packet[7] = _mm512_shuffle_f32x4(T5, T7, 0xdd);
} else {
T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
@ -1472,6 +1493,7 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
kernel.packet[2] = T2; kernel.packet[3] = T3;
kernel.packet[4] = T4; kernel.packet[5] = T5;
kernel.packet[6] = T6; kernel.packet[7] = T7;
}
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
@ -1549,7 +1571,9 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
PACK_OUTPUT_D(kernel.packet, tmp.packet, 3, 1);
}
template<bool for_trsm = false>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
// Transpose for trsm is the same as for gemm.
__m512d T0 = _mm512_unpacklo_pd(kernel.packet[0],kernel.packet[1]);
__m512d T1 = _mm512_unpackhi_pd(kernel.packet[0],kernel.packet[1]);
__m512d T2 = _mm512_unpacklo_pd(kernel.packet[2],kernel.packet[3]);

View File

@ -198,7 +198,7 @@ public:
r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5];
r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6];
r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7];
ptranspose<true>(r);
zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0];
zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1];
zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2];

View File

@ -44,6 +44,34 @@ template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const
return _mm512_castps512_ps256(a);
}
template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet16f>(const Packet16f& a) {
return _mm512_castps512_ps128(a);
}
template<> EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet8d>(const Packet8d& a) {
return _mm512_castpd512_pd256(a);
}
template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet8d>(const Packet8d& a) {
return _mm512_castpd512_pd128(a);
}
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8f>(const Packet8f& a) {
return _mm512_castps256_ps512(a);
}
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet4f>(const Packet4f& a) {
return _mm512_castps128_ps512(a);
}
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet4d>(const Packet4d& a) {
return _mm512_castpd256_pd512(a);
}
template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const Packet2d& a) {
return _mm512_castpd128_pd512(a);
}
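// These casts reinterpret bits between packet widths: narrowing keeps the low 128/256
// bits of a zmm register, widening places the source in the low lanes and leaves the
// upper lanes unspecified, matching the underlying _mm512_cast* intrinsics. The AVX512
// GEMM kernel uses them in its partial-tile load/store helpers.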
template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16f>(const Packet16f& a) {
return a;
}

View File

@ -8,9 +8,9 @@ namespace internal {
// Clang seems to excessively spill registers in the GEBP kernel on 32-bit arm.
// Here we specialize gebp_traits to eliminate these register spills.
// See #2138.
template<bool UnitResIncr>
struct gebp_traits <float,float,UnitResIncr,false,false,Architecture::NEON,GEBPPacketFull>
: gebp_traits<float,float,UnitResIncr,false,false,Architecture::Generic,GEBPPacketFull>
{
EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
{
@ -43,9 +43,9 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
#if EIGEN_ARCH_ARM64
template<bool UnitResIncr>
struct gebp_traits <float,float,UnitResIncr,false,false,Architecture::NEON,GEBPPacketFull>
: gebp_traits<float,float,UnitResIncr,false,false,Architecture::Generic,GEBPPacketFull>
{
typedef float RhsPacket;
typedef float32x4_t RhsPacketx4;
@ -108,9 +108,9 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
};
template<bool UnitResIncr>
struct gebp_traits <double,double,UnitResIncr,false,false,Architecture::NEON>
: gebp_traits<double,double,UnitResIncr,false,false,Architecture::Generic>
{
typedef double RhsPacket;

View File

@ -285,6 +285,10 @@ template<> EIGEN_STRONG_INLINE Packet4i padd<Packet4i>(const Packet4i& a, const
template<> EIGEN_STRONG_INLINE Packet16b padd<Packet16b>(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); }
template<typename Packet> EIGEN_STRONG_INLINE Packet padds(const Packet& a, const Packet& b);
template<> EIGEN_STRONG_INLINE Packet4f padds<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_add_ss(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d padds<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_add_sd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f psub<Packet4f>(const Packet4f& a, const Packet4f& b) { return _mm_sub_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d psub<Packet2d>(const Packet2d& a, const Packet2d& b) { return _mm_sub_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i psub<Packet4i>(const Packet4i& a, const Packet4i& b) { return _mm_sub_epi32(a,b); }
@ -370,6 +374,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f
template<> EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmadd_pd(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmsub_ps(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmsub_pd(a,b,c); }
template<typename Packet> EIGEN_STRONG_INLINE Packet pmadds(const Packet& a, const Packet& b, const Packet& c);
template<> EIGEN_STRONG_INLINE Packet4f pmadds<Packet4f>(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ss(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pmadds<Packet2d>(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_sd(a,b,c); }
#endif
#ifdef EIGEN_VECTORIZE_SSE4_1
@ -746,6 +754,15 @@ template<> EIGEN_STRONG_INLINE Packet16b ploadu<Packet16b>(const bool* from)
return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
}
// Load lower part of packet zero extending.
template<typename Packet> EIGEN_STRONG_INLINE Packet ploadl(const typename unpacket_traits<Packet>::type* from);
template<> EIGEN_STRONG_INLINE Packet4f ploadl<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(from))); }
template<> EIGEN_STRONG_INLINE Packet2d ploadl<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); }
// Load scalar
template<typename Packet> EIGEN_STRONG_INLINE Packet ploads(const typename unpacket_traits<Packet>::type* from);
template<> EIGEN_STRONG_INLINE Packet4f ploads<Packet4f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_ss(from); }
template<> EIGEN_STRONG_INLINE Packet2d ploads<Packet2d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm_load_sd(from); }
template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
{
@ -787,6 +804,14 @@ template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet4f&
template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstoreu<bool>(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); }
template<typename Scalar, typename Packet> EIGEN_STRONG_INLINE void pstorel(Scalar* to, const Packet& from);
template<> EIGEN_STRONG_INLINE void pstorel(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pi(reinterpret_cast<__m64*>(to), from); }
template<> EIGEN_STRONG_INLINE void pstorel(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storel_pd(to, from); }
template<typename Scalar, typename Packet> EIGEN_STRONG_INLINE void pstores(Scalar* to, const Packet& from);
template<> EIGEN_STRONG_INLINE void pstores(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_ss(to, from); }
template<> EIGEN_STRONG_INLINE void pstores(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_store_sd(to, from); }
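// Together with padds/pmadds and ploadl/ploads above, these single-element and
// low-half stores back the 32-bit and 64-bit tail cases of the AVX512 GEMM kernel's
// c_store/vaddm/vfmaddm helpers.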
template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
{
return _mm_set_ps(from[3*stride], from[2*stride], from[1*stride], from[0*stride]);

View File

@ -71,6 +71,14 @@ template<> EIGEN_STRONG_INLINE Packet2d pcast<Packet4f, Packet2d>(const Packet4f
return _mm_cvtps_pd(a);
}
template<> EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet4f>(const Packet4f& a) {
return _mm_castps_pd(a);
}
template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet2d>(const Packet2d& a) {
return _mm_castpd_ps(a);
}
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
return _mm_castps_si128(a);
}

View File

@ -2,6 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
// Modifications Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -23,7 +24,7 @@ enum GEBPPacketSizeType {
GEBPPacketQuarter
};
template<typename LhsScalar_, typename RhsScalar_, bool UnitResIncr=false, bool ConjLhs_=false, bool ConjRhs_=false, int Arch=Architecture::Target, int PacketSize_=GEBPPacketFull>
class gebp_traits;
@ -125,7 +126,7 @@ inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1, std::ptrdiff
template<typename LhsScalar, typename RhsScalar, int KcFactor, typename Index>
void evaluateProductBlockingSizesHeuristic(Index& k, Index& m, Index& n, Index num_threads = 1)
{
typedef gebp_traits<LhsScalar,RhsScalar, true> Traits;
// Explanations:
// Let's recall that the product algorithms form mc x kc vertical panels A' on the lhs and
@ -416,7 +417,7 @@ struct packet_conditional<GEBPPacketHalf, T1, T2, T3> { typedef T2 type; };
* cplx*real : unpack rhs to constant packets, ...
* real*cplx : load lhs as (a0,a0,a1,a1), and mul as usual
*/
template<typename LhsScalar_, typename RhsScalar_, bool UnitResIncr_, bool ConjLhs_, bool ConjRhs_, int Arch, int PacketSize_>
class gebp_traits
{
public:
@ -429,6 +430,7 @@ public:
PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
enum {
UnitResIncr = UnitResIncr_,
ConjLhs = ConjLhs_,
ConjRhs = ConjRhs_,
Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable,
@ -437,9 +439,17 @@ public:
ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1,
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
IsReal = std::is_same<LhsScalar, RhsScalar>::value
         && (std::is_same<LhsScalar, float>::value
             || std::is_same<LhsScalar, double>::value),
// register block size along the N direction must be 1 or 4
#if defined(EIGEN_VECTORIZE_AVX512)
// AVX512 supports nr = 8 for unit inner strides of the result matrix.
nr = IsReal && Vectorizable && UnitResIncr ? 8 : 4,
#else
nr = 4,
#endif
// register block size along the M direction (currently, this one cannot be modified)
default_mr = (plain_enum_min(16, NumberOfRegisters)/2/nr)*LhsPacketSize,
@ -545,8 +555,9 @@ public:
};
template<typename RealScalar, bool UnitResIncr_, bool ConjLhs_, int Arch, int PacketSize_>
class gebp_traits<std::complex<RealScalar>, RealScalar, UnitResIncr_, ConjLhs_, false, Arch, PacketSize_>
{
public:
typedef std::complex<RealScalar> LhsScalar;
@ -756,8 +767,8 @@ template<typename Packet> struct unpacket_traits<DoublePacket<Packet> > {
// return res;
// }
template<typename RealScalar, bool UnitResIncr_, bool ConjLhs_, bool ConjRhs_, int Arch, int PacketSize_>
class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, UnitResIncr_, ConjLhs_, ConjRhs_, Arch, PacketSize_ >
{
public:
typedef std::complex<RealScalar> Scalar;
@ -922,8 +933,8 @@ protected:
conj_helper<LhsScalar,RhsScalar,ConjLhs,ConjRhs> cj;
};
template<typename RealScalar, bool UnitResIncr, bool ConjRhs_, int Arch, int PacketSize_>
class gebp_traits<RealScalar, std::complex<RealScalar>, UnitResIncr, false, ConjRhs_, Arch, PacketSize_ >
{
public:
typedef std::complex<RealScalar> Scalar;
@ -1058,9 +1069,9 @@ protected:
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel
{
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketHalf> HalfTraits;
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketQuarter> QuarterTraits;
typedef typename Traits::ResScalar ResScalar;
typedef typename Traits::LhsPacket LhsPacket;
@ -1071,7 +1082,7 @@ struct gebp_kernel
typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 15>::type RhsPanel15;
typedef gebp_traits<RhsScalar,LhsScalar,DataMapper::incr == 1,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
typedef typename SwappedTraits::ResScalar SResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
@ -1109,11 +1120,11 @@ struct gebp_kernel
};
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs,
int SwappedLhsProgress = gebp_traits<RhsScalar,LhsScalar,DataMapper::incr == 1,ConjugateRhs,ConjugateLhs,Architecture::Target>::LhsProgress>
struct last_row_process_16_packets
{
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
typedef gebp_traits<RhsScalar,LhsScalar,DataMapper::incr == 1,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
typedef typename Traits::ResScalar ResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
@ -1141,8 +1152,8 @@ struct last_row_process_16_packets
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct last_row_process_16_packets<LhsScalar, RhsScalar, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs, 16> {
typedef gebp_traits<LhsScalar,RhsScalar,DataMapper::incr == 1,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
typedef gebp_traits<RhsScalar,LhsScalar,DataMapper::incr == 1,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
typedef typename Traits::ResScalar ResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
@ -1408,6 +1419,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
Index rows, Index depth, Index cols, ResScalar alpha,
Index strideA, Index strideB, Index offsetA, Index offsetB)
{
#if defined(EIGEN_VECTORIZE_AVX512)
if (nr == 8) {
bool done = gemm_kernel(
rows, cols, depth, alpha, blockA, blockB,
(ResScalar *)res.data(), res.stride(),
strideA, strideB, offsetA, offsetB);
if (done) return;
}
#endif
Traits traits; Traits traits;
SwappedTraits straits; SwappedTraits straits;
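The hunk above routes the nr == 8 case to the new AVX512 kernel and only falls through to the generic GEBP path when that kernel declines the problem. For context, here is a minimal standalone sketch of that try-then-fall-back pattern; every name in it is illustrative and not part of Eigen's actual internals.

#include <cstdint>

// Stand-in for the packed LHS/RHS panels produced by the packing routines.
struct PackedOperands { const float* blockA; const float* blockB; };

// Hypothetical specialized kernel: returns true when it handled the product,
// false when the caller must fall back (e.g. unsupported shape or strides).
bool avx512_gemm_kernel(std::int64_t rows, std::int64_t cols, std::int64_t depth,
                        float alpha, const PackedOperands& p,
                        float* C, std::int64_t ldc)
{
  if (ldc < rows) return false;     // decline: layout not supported
  // ... vectorized micro-kernel would run here ...
  (void)cols; (void)depth; (void)alpha; (void)p; (void)C;
  return true;
}

void generic_gebp(std::int64_t, std::int64_t, std::int64_t, float,
                  const PackedOperands&, float*, std::int64_t)
{
  // Portable reference path, kept unchanged as the fallback.
}

void run_gebp(std::int64_t rows, std::int64_t cols, std::int64_t depth, float alpha,
              const PackedOperands& p, float* C, std::int64_t ldc)
{
  // Try the specialized kernel first; fall through only if it declines.
  if (avx512_gemm_kernel(rows, cols, depth, alpha, p, C, ldc)) return;
  generic_gebp(rows, cols, depth, alpha, p, C, ldc);
}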
@ -2397,51 +2417,67 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, ColMajor, Co
    Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
    Index count = 0;
    const Index peeled_k = (depth/PacketSize)*PacketSize;
-   // if(nr>=8)
-   // {
-   //   for(Index j2=0; j2<packet_cols8; j2+=8)
-   //   {
-   //     // skip what we have before
-   //     if(PanelMode) count += 8 * offset;
-   //     const Scalar* b0 = &rhs[(j2+0)*rhsStride];
-   //     const Scalar* b1 = &rhs[(j2+1)*rhsStride];
-   //     const Scalar* b2 = &rhs[(j2+2)*rhsStride];
-   //     const Scalar* b3 = &rhs[(j2+3)*rhsStride];
-   //     const Scalar* b4 = &rhs[(j2+4)*rhsStride];
-   //     const Scalar* b5 = &rhs[(j2+5)*rhsStride];
-   //     const Scalar* b6 = &rhs[(j2+6)*rhsStride];
-   //     const Scalar* b7 = &rhs[(j2+7)*rhsStride];
-   //     Index k=0;
-   //     if(PacketSize==8) // TODO enable vectorized transposition for PacketSize==4
-   //     {
-   //       for(; k<peeled_k; k+=PacketSize) {
-   //         PacketBlock<Packet> kernel;
-   //         for (int p = 0; p < PacketSize; ++p) {
-   //           kernel.packet[p] = ploadu<Packet>(&rhs[(j2+p)*rhsStride+k]);
-   //         }
-   //         ptranspose(kernel);
-   //         for (int p = 0; p < PacketSize; ++p) {
-   //           pstoreu(blockB+count, cj.pconj(kernel.packet[p]));
-   //           count+=PacketSize;
-   //         }
-   //       }
-   //     }
-   //     for(; k<depth; k++)
-   //     {
-   //       blockB[count+0] = cj(b0[k]);
-   //       blockB[count+1] = cj(b1[k]);
-   //       blockB[count+2] = cj(b2[k]);
-   //       blockB[count+3] = cj(b3[k]);
-   //       blockB[count+4] = cj(b4[k]);
-   //       blockB[count+5] = cj(b5[k]);
-   //       blockB[count+6] = cj(b6[k]);
-   //       blockB[count+7] = cj(b7[k]);
-   //       count += 8;
-   //     }
-   //     // skip what we have after
-   //     if(PanelMode) count += 8 * (stride-offset-depth);
-   //   }
-   // }
+   if(nr>=8)
+   {
+     for(Index j2=0; j2<packet_cols8; j2+=8)
+     {
+       // skip what we have before
+       if(PanelMode) count += 8 * offset;
+       const LinearMapper dm0 = rhs.getLinearMapper(0, j2+0);
+       const LinearMapper dm1 = rhs.getLinearMapper(0, j2+1);
+       const LinearMapper dm2 = rhs.getLinearMapper(0, j2+2);
+       const LinearMapper dm3 = rhs.getLinearMapper(0, j2+3);
+       const LinearMapper dm4 = rhs.getLinearMapper(0, j2+4);
+       const LinearMapper dm5 = rhs.getLinearMapper(0, j2+5);
+       const LinearMapper dm6 = rhs.getLinearMapper(0, j2+6);
+       const LinearMapper dm7 = rhs.getLinearMapper(0, j2+7);
+       Index k=0;
+#if 0
+       // TODO Need to enable vectorized transposition.
+       if((PacketSize%8)==0) // TODO enable vectorized transposition for PacketSize==4
+       {
+         for(; k<peeled_k; k+=PacketSize) {
+           PacketBlock<Packet,(PacketSize%8)==0?8:PacketSize> kernel;
+           kernel.packet[0] = dm0.template loadPacket<Packet>(k);
+           kernel.packet[1] = dm1.template loadPacket<Packet>(k);
+           kernel.packet[2] = dm2.template loadPacket<Packet>(k);
+           kernel.packet[3] = dm3.template loadPacket<Packet>(k);
+           kernel.packet[4] = dm4.template loadPacket<Packet>(k);
+           kernel.packet[5] = dm5.template loadPacket<Packet>(k);
+           kernel.packet[6] = dm6.template loadPacket<Packet>(k);
+           kernel.packet[7] = dm7.template loadPacket<Packet>(k);
+           ptranspose(kernel);
+           pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0]));
+           pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1%PacketSize]));
+           pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2%PacketSize]));
+           pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3%PacketSize]));
+           pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel.packet[4%PacketSize]));
+           pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel.packet[5%PacketSize]));
+           pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel.packet[6%PacketSize]));
+           pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel.packet[7%PacketSize]));
+           count+=8*PacketSize;
+         }
+       }
+#endif
+       for(; k<depth; k++)
+       {
+         blockB[count+0] = cj(dm0(k));
+         blockB[count+1] = cj(dm1(k));
+         blockB[count+2] = cj(dm2(k));
+         blockB[count+3] = cj(dm3(k));
+         blockB[count+4] = cj(dm4(k));
+         blockB[count+5] = cj(dm5(k));
+         blockB[count+6] = cj(dm6(k));
+         blockB[count+7] = cj(dm7(k));
+         count += 8;
+       }
+       // skip what we have after
+       if(PanelMode) count += 8 * (stride-offset-depth);
+     }
+   }
    if(nr>=4)
    {
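The re-enabled nr>=8 branch above packs eight RHS columns so that, for every depth index k, the eight values b0[k] through b7[k] land contiguously in blockB; that is the layout the 8-wide micro-kernel consumes with one packet load per k. A simplified scalar sketch of that interleaving follows; the function and variable names are placeholders, and a column-major input is assumed.

#include <cstddef>
#include <vector>

// Pack an 8-column panel of a column-major RHS into an interleaved buffer:
// packed[8*k + c] == rhs[(j2 + c)*rhsStride + k].  This is not Eigen's
// gemm_pack_rhs, only an illustration of the layout produced for nr==8.
std::vector<double> pack_rhs_panel8(const double* rhs, std::ptrdiff_t rhsStride,
                                    std::ptrdiff_t depth, std::ptrdiff_t j2)
{
  std::vector<double> packed;
  packed.reserve(8 * depth);
  const double* col[8];
  for (int c = 0; c < 8; ++c)
    col[c] = rhs + (j2 + c) * rhsStride;   // pointer to each of the 8 columns
  for (std::ptrdiff_t k = 0; k < depth; ++k)
    for (int c = 0; c < 8; ++c)
      packed.push_back(col[c][k]);         // interleave: 8 consecutive values per k
  return packed;
}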
@ -2522,39 +2558,50 @@ struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMo
    Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
    Index count = 0;
-   // if(nr>=8)
-   // {
-   //   for(Index j2=0; j2<packet_cols8; j2+=8)
-   //   {
-   //     // skip what we have before
-   //     if(PanelMode) count += 8 * offset;
-   //     for(Index k=0; k<depth; k++)
-   //     {
-   //       if (PacketSize==8) {
-   //         Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
-   //         pstoreu(blockB+count, cj.pconj(A));
-   //       } else if (PacketSize==4) {
-   //         Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
-   //         Packet B = ploadu<Packet>(&rhs[k*rhsStride + j2 + PacketSize]);
-   //         pstoreu(blockB+count, cj.pconj(A));
-   //         pstoreu(blockB+count+PacketSize, cj.pconj(B));
-   //       } else {
-   //         const Scalar* b0 = &rhs[k*rhsStride + j2];
-   //         blockB[count+0] = cj(b0[0]);
-   //         blockB[count+1] = cj(b0[1]);
-   //         blockB[count+2] = cj(b0[2]);
-   //         blockB[count+3] = cj(b0[3]);
-   //         blockB[count+4] = cj(b0[4]);
-   //         blockB[count+5] = cj(b0[5]);
-   //         blockB[count+6] = cj(b0[6]);
-   //         blockB[count+7] = cj(b0[7]);
-   //       }
-   //       count += 8;
-   //     }
-   //     // skip what we have after
-   //     if(PanelMode) count += 8 * (stride-offset-depth);
-   //   }
-   // }
+   if(nr>=8)
+   {
+     for(Index j2=0; j2<packet_cols8; j2+=8)
+     {
+       // skip what we have before
+       if(PanelMode) count += 8 * offset;
+       for(Index k=0; k<depth; k++)
+       {
+         if (PacketSize==8) {
+           // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
+           Packet A = rhs.template loadPacket<Packet>(k, j2);
+           pstoreu(blockB+count, cj.pconj(A));
+         } else if (HasHalf && HalfPacketSize==8) {
+           HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
+           pstoreu(blockB+count, cj.pconj(A));
+         } else if (HasQuarter && QuarterPacketSize==8) {
+           QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
+           pstoreu(blockB+count, cj.pconj(A));
+         } else if (PacketSize==4) {
+           // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
+           // Packet B = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2 + PacketSize]);
+           Packet A = rhs.template loadPacket<Packet>(k, j2);
+           Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
+           pstoreu(blockB+count, cj.pconj(A));
+           pstoreu(blockB+count+PacketSize, cj.pconj(B));
+         } else {
+           // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2];
+           const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
+           blockB[count+0] = cj(dm0(0));
+           blockB[count+1] = cj(dm0(1));
+           blockB[count+2] = cj(dm0(2));
+           blockB[count+3] = cj(dm0(3));
+           blockB[count+4] = cj(dm0(4));
+           blockB[count+5] = cj(dm0(5));
+           blockB[count+6] = cj(dm0(6));
+           blockB[count+7] = cj(dm0(7));
+         }
+         count += 8;
+       }
+       // skip what we have after
+       if(PanelMode) count += 8 * (stride-offset-depth);
+     }
+   }
    if(nr>=4)
    {
      for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
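In the row-major variant above, the eight RHS values for a given k are already contiguous in memory, so the packing degenerates to a straight copy: one full packet when eight scalars fit in a register, half or quarter packets for narrower types, and an element-wise copy otherwise. The following is a rough sketch of that copy under the assumption that AVX is available for the vector path; the function name and the simplification to float are mine, not Eigen's.

#include <cstddef>
#include <cstring>
#ifdef __AVX__
#include <immintrin.h>
#endif

// Copy 8 contiguous RHS values per depth index into the packed buffer.
void pack_rhs_row_major8(const float* rhs, std::ptrdiff_t rhsStride,
                         std::ptrdiff_t depth, std::ptrdiff_t j2, float* blockB)
{
  for (std::ptrdiff_t k = 0; k < depth; ++k, blockB += 8) {
    const float* src = rhs + k * rhsStride + j2;   // row k, columns j2..j2+7
#ifdef __AVX__
    _mm256_storeu_ps(blockB, _mm256_loadu_ps(src)); // eight floats in one unaligned load/store
#else
    std::memcpy(blockB, src, 8 * sizeof(float));    // portable scalar fallback
#endif
  }
}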

View File

@ -26,7 +26,7 @@ template<
         int ResInnerStride>
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride>
{
-  typedef gebp_traits<RhsScalar,LhsScalar> Traits;
+  typedef gebp_traits<RhsScalar,LhsScalar,ResInnerStride == 1> Traits;
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static EIGEN_STRONG_INLINE void run(
@ -57,7 +57,7 @@ template<
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride>
{
-  typedef gebp_traits<LhsScalar,RhsScalar> Traits;
+  typedef gebp_traits<LhsScalar,RhsScalar, ResInnerStride == 1> Traits;
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static void run(Index rows, Index cols, Index depth,
@ -287,7 +287,6 @@ class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, M
  };
  typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
  typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
-  typedef gebp_traits<LhsScalar,RhsScalar> Traits;
  enum {
    SizeA = ActualRows * MaxDepth,
    SizeB = ActualCols * MaxDepth
@ -336,7 +335,6 @@ class gemm_blocking_space<StorageOrder,LhsScalar_,RhsScalar_,MaxRows, MaxCols, M
  };
  typedef std::conditional_t<Transpose,RhsScalar_,LhsScalar_> LhsScalar;
  typedef std::conditional_t<Transpose,LhsScalar_,RhsScalar_> RhsScalar;
-  typedef gebp_traits<LhsScalar,RhsScalar> Traits;
  Index m_sizeA;
  Index m_sizeB;

View File

@ -67,7 +67,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
                           ResScalar* _res, Index resIncr, Index resStride,
                           const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
  {
-   typedef gebp_traits<LhsScalar,RhsScalar> Traits;
+   typedef gebp_traits<LhsScalar,RhsScalar,ResInnerStride == 1> Traits;
    typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
    typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
@ -140,7 +140,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo>
struct tribb_kernel
{
-  typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits;
+  typedef gebp_traits<LhsScalar,RhsScalar,ResInnerStride == 1,ConjLhs,ConjRhs> Traits;
  typedef typename Traits::ResScalar ResScalar;
  enum {

View File

@ -55,7 +55,7 @@ template< \
         int RhsStorageOrder, bool ConjugateRhs> \
struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1> \
{ \
-  typedef gebp_traits<EIGTYPE,EIGTYPE> Traits; \
+  typedef gebp_traits<EIGTYPE,EIGTYPE,true> Traits; \
\
  static void run(Index rows, Index cols, Index depth, \
                  const EIGTYPE* _lhs, Index lhsStride, \

View File

@ -351,7 +351,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,t
  {
    Index size = rows;
-   typedef gebp_traits<Scalar,Scalar> Traits;
+   typedef gebp_traits<Scalar,Scalar,ResInnerStride == 1> Traits;
    typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
    typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor> LhsTransposeMapper;
@ -446,7 +446,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,f
  {
    Index size = cols;
-   typedef gebp_traits<Scalar,Scalar> Traits;
+   typedef gebp_traits<Scalar,Scalar,ResInnerStride == 1> Traits;
    typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
    typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;

View File

@ -89,7 +89,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
                                        RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
{
-  typedef gebp_traits<Scalar,Scalar> Traits;
+  typedef gebp_traits<Scalar,Scalar,ResInnerStride == 1> Traits;
  enum {
    SmallPanelWidth = 2 * plain_enum_max(Traits::mr, Traits::nr),
    IsLower = (Mode&Lower) == Lower,
@ -247,7 +247,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
                                        LhsStorageOrder,ConjugateLhs,
                                        RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
{
-  typedef gebp_traits<Scalar,Scalar> Traits;
+  typedef gebp_traits<Scalar,Scalar,ResInnerStride == 1> Traits;
  enum {
    SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr),
    IsLower = (Mode&Lower) == Lower,

View File

@ -189,7 +189,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
  TriMapper tri(_tri, triStride);
  OtherMapper other(_other, otherStride, otherIncr);
-  typedef gebp_traits<Scalar,Scalar> Traits;
+  typedef gebp_traits<Scalar,Scalar,OtherInnerStride == 1> Traits;
  enum {
    SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr),
@ -336,7 +336,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
  LhsMapper lhs(_other, otherStride, otherIncr);
  RhsMapper rhs(_tri, triStride);
-  typedef gebp_traits<Scalar,Scalar> Traits;
+  typedef gebp_traits<Scalar,Scalar,OtherInnerStride == 1> Traits;
  enum {
    RhsStorageOrder = TriStorageOrder,
    SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr),

View File

@ -173,6 +173,7 @@ class blas_data_mapper<Scalar,Index,StorageOrder,AlignmentType,1>
public:
  typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
  typedef BlasVectorMapper<Scalar, Index> VectorMapper;
+  static constexpr int incr = 1;
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1)
    : m_data(data), m_stride(stride)
@ -285,6 +286,7 @@ class blas_data_mapper
{
public:
  typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;
+  static constexpr int incr = Incr;
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {}
@ -402,6 +404,9 @@ public:
    storePacketBlock_helper<SubPacket, Scalar, n, n-1> spb;
    spb.store(this, i,j,block);
  }
+  EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
+  EIGEN_DEVICE_FUNC Scalar* data() const { return m_data; }
protected:
  Scalar* EIGEN_RESTRICT m_data;
  const Index m_stride;
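The new static incr member and the stride()/data() accessors let callers query, at compile time, whether a mapper addresses memory with unit inner stride, and forward that as the extra boolean now carried by gebp_traits (the DataMapper::incr == 1 and ResInnerStride == 1 arguments seen throughout this commit). The following is a toy illustration of that compile-time plumbing; all class names here are made up for the sketch.

// A mapper that advertises its inner increment as a compile-time constant.
template <int Incr>
struct toy_data_mapper {
  static constexpr int incr = Incr;   // 1 means the inner stride is contiguous
};

// Traits that pick a blocking width depending on whether stores are contiguous.
template <typename Scalar, bool UnitResIncr>
struct toy_gebp_traits {
  static constexpr int nr = UnitResIncr ? 8 : 4;
};

// The pattern used in the diff: pass `DataMapper::incr == 1` into the traits.
template <typename Scalar, typename DataMapper>
using traits_for = toy_gebp_traits<Scalar, DataMapper::incr == 1>;

static_assert(traits_for<float, toy_data_mapper<1>>::nr == 8, "unit inner stride -> wider blocking");
static_assert(traits_for<float, toy_data_mapper<2>>::nr == 4, "strided output -> narrower blocking");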

View File

@ -143,8 +143,8 @@ int main()
// Specialize GEBP kernel and traits for mpreal (no need for peeling, nor complicated stuff)
// This also permits to directly call mpfr's routines and avoid many temporaries produced by mpreal
-template<>
-class gebp_traits<mpfr::mpreal, mpfr::mpreal, false, false>
+template<bool UnitResIncr>
+class gebp_traits<mpfr::mpreal, mpfr::mpreal, UnitResIncr, false, false>
{
  public:
    typedef mpfr::mpreal ResScalar;
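Because gebp_traits now takes the unit-increment boolean before the conjugation flags, full specializations maintained outside Eigen (such as the mpreal one above) have to become partial specializations over that new parameter. A generic before/after illustration with placeholder names, not tied to any real library type:

// Primary template gains an extra bool parameter ahead of the conjugation flags.
template <typename T, bool UnitResIncr, bool ConjLhs = false, bool ConjRhs = false>
struct toy_gebp_traits { /* generic implementation */ };

// Before: a full specialization for one scalar type (old 3-parameter form):
//   template <> struct toy_gebp_traits<long double, false, false> { ... };
// After: a partial specialization, kept generic in the new boolean so it still
// matches both the unit-stride and the strided instantiations.
template <bool UnitResIncr>
struct toy_gebp_traits<long double, UnitResIncr, false, false> {
  typedef long double ResScalar;   // mirrors the mpreal specialization above
};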