Use 3px8/2px8/1px8/1x8 gebp_kernel on arm64-neon

Lianhuang Li 2022-09-21 16:36:40 +00:00 committed by Rasmus Munk Larsen
parent 7b2901e2aa
commit 23299632c2
2 changed files with 796 additions and 97 deletions
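The kernels added here widen the rhs panel from four to eight columns on arm64 NEON. As a rough sketch of why the largest, 3-packet x 8-column shape fits the architecture (illustrative code, not Eigen's; the function and macro names are made up): 24 accumulators plus 3 lhs packets and 2 rhs packets stay within the 32 NEON vector registers.

#include <arm_neon.h>

// One depth-step of a hypothetical 3-packet x 8-column float micro-kernel.
static inline void micro_step_3x8(const float* blA, const float* blB,
                                  float32x4_t C[3][8])
{
  float32x4_t A0 = vld1q_f32(blA + 0);
  float32x4_t A1 = vld1q_f32(blA + 4);
  float32x4_t A2 = vld1q_f32(blA + 8);
  float32x4_t B0 = vld1q_f32(blB + 0); // rhs values for columns 0..3
  float32x4_t B1 = vld1q_f32(blB + 4); // rhs values for columns 4..7
#define STEP(c, B, lane)                           \
  C[0][c] = vfmaq_laneq_f32(C[0][c], A0, B, lane); \
  C[1][c] = vfmaq_laneq_f32(C[1][c], A1, B, lane); \
  C[2][c] = vfmaq_laneq_f32(C[2][c], A2, B, lane)
  STEP(0, B0, 0); STEP(1, B0, 1); STEP(2, B0, 2); STEP(3, B0, 3);
  STEP(4, B1, 0); STEP(5, B1, 1); STEP(6, B1, 2); STEP(7, B1, 3);
#undef STEP
}

Each step issues 24 lane-indexed fmas; the real 3pX8 kernel below notes "27 registers are taken (24 for acc, 3 for lhs)".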


@@ -49,7 +49,9 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
{
typedef float RhsPacket;
typedef float32x4_t RhsPacketx4;
enum {
nr = 8
};
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
{
dest = *b;
@@ -77,7 +79,6 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
{
c = vfmaq_n_f32(c, a, b);
}
// NOTE: Template parameter inference failed when compiled with Android NDK:
// "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
@@ -94,9 +95,10 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
template<int LaneID>
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
{
-#if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
-// workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
-// vfmaq_laneq_f32 is implemented through a costly dup
+#if EIGEN_COMP_GNUC_STRICT
+// 1. workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
+// vfmaq_laneq_f32 is implemented through a costly dup, which was fixed in gcc9
+// 2. workaround the gcc register split problem on arm64-neon
if(LaneID==0) asm("fmla %0.4s, %1.4s, %2.s[0]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==1) asm("fmla %0.4s, %1.4s, %2.s[1]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==2) asm("fmla %0.4s, %1.4s, %2.s[2]\n" : "+w" (c) : "w" (a), "w" (b) : );
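For reference, the intrinsic named in the comment and the inline asm that replaces it, as standalone functions (illustration only; the function names are made up):

#include <arm_neon.h>

// Intrinsic form: on gcc before 9 this lowered to a dup followed by a
// plain fmla (gcc bug 89101) instead of one lane-indexed fmla.
float32x4_t madd_lane1_intrinsic(float32x4_t c, float32x4_t a, float32x4_t b)
{
  return vfmaq_laneq_f32(c, a, b, 1);
}

// Inline-asm form, as in madd_helper above: emits the single
// fmla-by-element instruction and keeps b in one register.
float32x4_t madd_lane1_asm(float32x4_t c, float32x4_t a, float32x4_t b)
{
  asm("fmla %0.4s, %1.4s, %2.s[1]" : "+w" (c) : "w" (a), "w" (b) : );
  return c;
}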
@@ -113,7 +115,9 @@ struct gebp_traits <double,double,false,false,Architecture::NEON>
: gebp_traits<double,double,false,false,Architecture::Generic>
{
typedef double RhsPacket;
enum {
nr = 8
};
struct RhsPacketx4 {
float64x2_t B_0, B_1;
};
@@ -163,9 +167,10 @@ struct gebp_traits <double,double,false,false,Architecture::NEON>
template <int LaneID>
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
{
-#if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
-// workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
-// vfmaq_laneq_f64 is implemented through a costly dup
+#if EIGEN_COMP_GNUC_STRICT
+// 1. workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
+// vfmaq_laneq_f64 is implemented through a costly dup, which was fixed in gcc9
+// 2. workaround the gcc register split problem on arm64-neon
if(LaneID==0) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
else if(LaneID==1) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
else if(LaneID==2) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : );
@@ -179,6 +184,77 @@ struct gebp_traits <double,double,false,false,Architecture::NEON>
}
};
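A small sketch of the double-precision rhs layout implied by the asm above (illustrative, not Eigen's API): the four rhs values of a column quad span the two float64x2_t members, and LaneID 2 and 3 select lanes 0 and 1 of B_1.

#include <arm_neon.h>

struct RhsQuadF64 { float64x2_t B_0, B_1; }; // b[0],b[1] | b[2],b[3]

RhsQuadF64 load_rhs_quad(const double* b)
{
  RhsQuadF64 q;
  q.B_0 = vld1q_f64(b);     // lanes 0,1 -> LaneID 0,1
  q.B_1 = vld1q_f64(b + 2); // lanes 0,1 -> LaneID 2,3
  return q;
}

// LaneID==3 in madd_helper corresponds to lane 1 of B_1:
float64x2_t madd_lane3(float64x2_t c, float64x2_t a, const RhsQuadF64& q)
{
  return vfmaq_laneq_f64(c, a, q.B_1, 1);
}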
#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
template<>
struct gebp_traits <half,half,false,false,Architecture::NEON>
: gebp_traits<half,half,false,false,Architecture::Generic>
{
typedef half RhsPacket;
typedef float16x4_t RhsPacketx4;
typedef float16x4_t PacketHalf;
enum {
nr = 8
};
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
{
dest = *b;
}
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
{
dest = vld1_f16((const __fp16 *)b);
}
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const
{
dest = *b;
}
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
{}
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
{
loadRhs(b,dest);
}
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
{
c = vfmaq_n_f16(c, a, b);
}
EIGEN_STRONG_INLINE void madd(const PacketHalf& a, const RhsPacket& b, PacketHalf& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
{
c = vfma_n_f16(c, a, b);
}
// NOTE: Template parameter inference failed when compiled with Android NDK:
// "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
{ madd_helper<0>(a, b, c); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const
{ madd_helper<1>(a, b, c); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const
{ madd_helper<2>(a, b, c); }
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const
{ madd_helper<3>(a, b, c); }
private:
template<int LaneID>
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
{
#if EIGEN_COMP_GNUC_STRICT
// 1. vfmaq_lane_f16 is implemented through a costly dup
// 2. workaround the gcc register split problem on arm64-neon
if(LaneID==0) asm("fmla %0.8h, %1.8h, %2.h[0]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==1) asm("fmla %0.8h, %1.8h, %2.h[1]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==2) asm("fmla %0.8h, %1.8h, %2.h[2]\n" : "+w" (c) : "w" (a), "w" (b) : );
else if(LaneID==3) asm("fmla %0.8h, %1.8h, %2.h[3]\n" : "+w" (c) : "w" (a), "w" (b) : );
#else
c = vfmaq_lane_f16(c, a, b, LaneID);
#endif
}
};
#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
#endif // EIGEN_ARCH_ARM64
} // namespace internal
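The fp16 traits above follow the same pattern with a four-lane half-precision rhs packet. A standalone sketch (illustration only, assuming fp16 vector arithmetic support):

#include <arm_neon.h>

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Four rhs halves sit in a float16x4_t; the lhs packet is 8 halves wide.
float16x8_t madd_half_lane0(float16x8_t c, float16x8_t a, const __fp16* b)
{
  float16x4_t rhs = vld1_f16(b);       // b[0..3], as in loadRhs above
  return vfmaq_lane_f16(c, a, rhs, 0); // c += a * b[0]
}
#endif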


@@ -1070,6 +1070,7 @@ struct gebp_kernel
typedef typename Traits::RhsPacketx4 RhsPacketx4;
typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 15>::type RhsPanel15;
typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 27>::type RhsPanel27;
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
@@ -1215,13 +1216,135 @@ struct lhs_process_one_packet
int prefetch_res_offset, Index peeled_kc, Index pk, Index cols, Index depth, Index packet_cols4)
{
GEBPTraits traits;
Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
// loops on each largest micro horizontal panel of lhs
// (LhsProgress x depth)
for(Index i=peelStart; i<peelEnd; i+=LhsProgress)
{
for(Index j2=0; j2<packet_cols8; j2+=8)
{
const LhsScalar* blA = &blockA[i*strideA+offsetA*(LhsProgress)];
prefetch(&blA[0]);
// gets res block as register
AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
traits.initAcc(C0);
traits.initAcc(C1);
traits.initAcc(C2);
traits.initAcc(C3);
traits.initAcc(C4);
traits.initAcc(C5);
traits.initAcc(C6);
traits.initAcc(C7);
LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
r0.prefetch(prefetch_res_offset);
r1.prefetch(prefetch_res_offset);
r2.prefetch(prefetch_res_offset);
r3.prefetch(prefetch_res_offset);
r4.prefetch(prefetch_res_offset);
r5.prefetch(prefetch_res_offset);
r6.prefetch(prefetch_res_offset);
r7.prefetch(prefetch_res_offset);
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
prefetch(&blB[0]);
LhsPacket A0;
for(Index k=0; k<peeled_kc; k+=pk)
{
RhsPacketx4 rhs_panel;
RhsPacket T0;
#define EIGEN_GEBGP_ONESTEP(K) \
do { \
EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX8"); \
traits.loadLhs(&blA[(0 + 1 * K) * LhsProgress], A0); \
traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX8"); \
} while (false)
EIGEN_ASM_COMMENT("begin gebp micro kernel 1pX8");
EIGEN_GEBGP_ONESTEP(0);
EIGEN_GEBGP_ONESTEP(1);
EIGEN_GEBGP_ONESTEP(2);
EIGEN_GEBGP_ONESTEP(3);
EIGEN_GEBGP_ONESTEP(4);
EIGEN_GEBGP_ONESTEP(5);
EIGEN_GEBGP_ONESTEP(6);
EIGEN_GEBGP_ONESTEP(7);
blB += pk*8*RhsProgress;
blA += pk*(1*LhsProgress);
EIGEN_ASM_COMMENT("end gebp micro kernel 1pX8");
}
// process remaining peeled loop
for(Index k=peeled_kc; k<depth; k++)
{
RhsPacketx4 rhs_panel;
RhsPacket T0;
EIGEN_GEBGP_ONESTEP(0);
blB += 8*RhsProgress;
blA += 1*LhsProgress;
}
#undef EIGEN_GEBGP_ONESTEP
ResPacket R0, R1;
ResPacket alphav = pset1<ResPacket>(alpha);
R0 = r0.template loadPacket<ResPacket>(0);
R1 = r1.template loadPacket<ResPacket>(0);
traits.acc(C0, alphav, R0);
traits.acc(C1, alphav, R1);
r0.storePacket(0, R0);
r1.storePacket(0, R1);
R0 = r2.template loadPacket<ResPacket>(0);
R1 = r3.template loadPacket<ResPacket>(0);
traits.acc(C2, alphav, R0);
traits.acc(C3, alphav, R1);
r2.storePacket(0, R0);
r3.storePacket(0, R1);
R0 = r4.template loadPacket<ResPacket>(0);
R1 = r5.template loadPacket<ResPacket>(0);
traits.acc(C4, alphav, R0);
traits.acc(C5, alphav, R1);
r4.storePacket(0, R0);
r5.storePacket(0, R1);
R0 = r6.template loadPacket<ResPacket>(0);
R1 = r7.template loadPacket<ResPacket>(0);
traits.acc(C6, alphav, R0);
traits.acc(C7, alphav, R1);
r6.storePacket(0, R0);
r7.storePacket(0, R1);
}
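// Note on the epilogue above: traits.acc(C, alphav, R) computes
// R += C * alphav, so each of the eight result columns is updated as
// res(:, j2+c) += alpha * C_c, two columns per load/store round trip.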
// loops on each largest micro vertical panel of rhs (depth * nr)
-for(Index j2=0; j2<packet_cols4; j2+=nr)
+for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
{
// We select a LhsProgress x nr micro block of res
// which is entirely stored into 1 x nr registers.
@@ -1257,7 +1380,7 @@ struct lhs_process_one_packet
r3.prefetch(prefetch_res_offset);
// performs "inner" products
-const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
prefetch(&blB[0]);
LhsPacket A0, A1;
@@ -1415,6 +1538,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
if(strideB==-1) strideB = depth;
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
const Index peeled_mc3 = mr>=3*Traits::LhsProgress ? (rows/(3*LhsProgress))*(3*LhsProgress) : 0;
const Index peeled_mc2 = mr>=2*Traits::LhsProgress ? peeled_mc3+((rows-peeled_mc3)/(2*LhsProgress))*(2*LhsProgress) : 0;
const Index peeled_mc1 = mr>=1*Traits::LhsProgress ? peeled_mc2+((rows-peeled_mc2)/(1*LhsProgress))*(1*LhsProgress) : 0;
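A worked example of this blocking arithmetic, with sizes assumed purely for illustration:

// cols = 21, rows = 30, nr = 8, LhsProgress = 4 (float, plain NEON):
//   packet_cols8 = (21/8)*8 = 16   -> columns  0..15 use the x8 kernels
//   packet_cols4 = (21/4)*4 = 20   -> columns 16..19 use the x4 kernels,
//                                     column 20 goes to the scalar tail
//   peeled_mc3   = (30/12)*12 = 24 -> rows 0..23 use the 3-packet kernels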
@@ -1443,7 +1567,219 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
for(Index i1=0; i1<peeled_mc3; i1+=actual_panel_rows)
{
const Index actual_panel_end = (std::min)(i1+actual_panel_rows, peeled_mc3);
-for(Index j2=0; j2<packet_cols4; j2+=nr)
+// nr >= 8
+for(Index j2=0; j2<packet_cols8; j2+=8)
{
for(Index i=i1; i<actual_panel_end; i+=3*LhsProgress)
{
const LhsScalar* blA = &blockA[i*strideA+offsetA*(3*LhsProgress)];
prefetch(&blA[0]);
// gets res block as register
AccPacket C0, C1, C2, C3, C4, C5, C6, C7,
C8, C9, C10, C11, C12, C13, C14, C15,
C16, C17, C18, C19, C20, C21, C22, C23;
traits.initAcc(C0); traits.initAcc(C1); traits.initAcc(C2); traits.initAcc(C3);
traits.initAcc(C4); traits.initAcc(C5); traits.initAcc(C6); traits.initAcc(C7);
traits.initAcc(C8); traits.initAcc(C9); traits.initAcc(C10); traits.initAcc(C11);
traits.initAcc(C12); traits.initAcc(C13); traits.initAcc(C14); traits.initAcc(C15);
traits.initAcc(C16); traits.initAcc(C17); traits.initAcc(C18); traits.initAcc(C19);
traits.initAcc(C20); traits.initAcc(C21); traits.initAcc(C22); traits.initAcc(C23);
LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
r0.prefetch(0);
r1.prefetch(0);
r2.prefetch(0);
r3.prefetch(0);
r4.prefetch(0);
r5.prefetch(0);
r6.prefetch(0);
r7.prefetch(0);
// performs "inner" products
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
prefetch(&blB[0]);
LhsPacket A0, A1;
for(Index k=0; k<peeled_kc; k+=pk)
{
EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX8");
// 27 registers are taken (24 for acc, 3 for lhs).
RhsPanel27 rhs_panel;
RhsPacket T0;
LhsPacket A2;
#if EIGEN_COMP_GNUC_STRICT && EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && !(EIGEN_GNUC_AT_LEAST(9,0))
// see http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1633
// without this workaround A0, A1, and A2 are loaded in the same register,
// which is not good for pipelining
#define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND __asm__ ("" : "+w,m" (A0), "+w,m" (A1), "+w,m" (A2));
#else
#define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND
#endif
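// The empty asm lists A0, A1 and A2 as read-write operands, so gcc must
// treat each as live in its own register here instead of folding the
// three loads into one reused register.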
#define EIGEN_GEBP_ONESTEP(K) \
do { \
EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX8"); \
traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND \
traits.loadRhs(blB + (0 + 8 * K) * Traits::RhsProgress, rhs_panel); \
traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
traits.madd(A2, rhs_panel, C16, T0, fix<0>); \
traits.updateRhs(blB + (1 + 8 * K) * Traits::RhsProgress, rhs_panel); \
traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
traits.madd(A2, rhs_panel, C17, T0, fix<1>); \
traits.updateRhs(blB + (2 + 8 * K) * Traits::RhsProgress, rhs_panel); \
traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
traits.madd(A2, rhs_panel, C18, T0, fix<2>); \
traits.updateRhs(blB + (3 + 8 * K) * Traits::RhsProgress, rhs_panel); \
traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
traits.madd(A2, rhs_panel, C19, T0, fix<3>); \
traits.loadRhs(blB + (4 + 8 * K) * Traits::RhsProgress, rhs_panel); \
traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
traits.madd(A2, rhs_panel, C20, T0, fix<0>); \
traits.updateRhs(blB + (5 + 8 * K) * Traits::RhsProgress, rhs_panel); \
traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
traits.madd(A2, rhs_panel, C21, T0, fix<1>); \
traits.updateRhs(blB + (6 + 8 * K) * Traits::RhsProgress, rhs_panel); \
traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
traits.madd(A2, rhs_panel, C22, T0, fix<2>); \
traits.updateRhs(blB + (7 + 8 * K) * Traits::RhsProgress, rhs_panel); \
traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
traits.madd(A2, rhs_panel, C23, T0, fix<3>); \
EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX8"); \
} while (false)
EIGEN_GEBP_ONESTEP(0);
EIGEN_GEBP_ONESTEP(1);
EIGEN_GEBP_ONESTEP(2);
EIGEN_GEBP_ONESTEP(3);
EIGEN_GEBP_ONESTEP(4);
EIGEN_GEBP_ONESTEP(5);
EIGEN_GEBP_ONESTEP(6);
EIGEN_GEBP_ONESTEP(7);
blB += pk * 8 * RhsProgress;
blA += pk * 3 * Traits::LhsProgress;
EIGEN_ASM_COMMENT("end gebp micro kernel 3pX8");
}
// process remaining peeled loop
for (Index k = peeled_kc; k < depth; k++)
{
RhsPanel27 rhs_panel;
RhsPacket T0;
LhsPacket A2;
EIGEN_GEBP_ONESTEP(0);
blB += 8 * RhsProgress;
blA += 3 * Traits::LhsProgress;
}
#undef EIGEN_GEBP_ONESTEP
ResPacket R0, R1, R2;
ResPacket alphav = pset1<ResPacket>(alpha);
R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C0, alphav, R0);
traits.acc(C8, alphav, R1);
traits.acc(C16, alphav, R2);
r0.storePacket(0 * Traits::ResPacketSize, R0);
r0.storePacket(1 * Traits::ResPacketSize, R1);
r0.storePacket(2 * Traits::ResPacketSize, R2);
R0 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r1.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C1, alphav, R0);
traits.acc(C9, alphav, R1);
traits.acc(C17, alphav, R2);
r1.storePacket(0 * Traits::ResPacketSize, R0);
r1.storePacket(1 * Traits::ResPacketSize, R1);
r1.storePacket(2 * Traits::ResPacketSize, R2);
R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r2.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C2, alphav, R0);
traits.acc(C10, alphav, R1);
traits.acc(C18, alphav, R2);
r2.storePacket(0 * Traits::ResPacketSize, R0);
r2.storePacket(1 * Traits::ResPacketSize, R1);
r2.storePacket(2 * Traits::ResPacketSize, R2);
R0 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r3.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C3, alphav, R0);
traits.acc(C11, alphav, R1);
traits.acc(C19, alphav, R2);
r3.storePacket(0 * Traits::ResPacketSize, R0);
r3.storePacket(1 * Traits::ResPacketSize, R1);
r3.storePacket(2 * Traits::ResPacketSize, R2);
R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r4.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C4, alphav, R0);
traits.acc(C12, alphav, R1);
traits.acc(C20, alphav, R2);
r4.storePacket(0 * Traits::ResPacketSize, R0);
r4.storePacket(1 * Traits::ResPacketSize, R1);
r4.storePacket(2 * Traits::ResPacketSize, R2);
R0 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r5.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C5, alphav, R0);
traits.acc(C13, alphav, R1);
traits.acc(C21, alphav, R2);
r5.storePacket(0 * Traits::ResPacketSize, R0);
r5.storePacket(1 * Traits::ResPacketSize, R1);
r5.storePacket(2 * Traits::ResPacketSize, R2);
R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r6.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C6, alphav, R0);
traits.acc(C14, alphav, R1);
traits.acc(C22, alphav, R2);
r6.storePacket(0 * Traits::ResPacketSize, R0);
r6.storePacket(1 * Traits::ResPacketSize, R1);
r6.storePacket(2 * Traits::ResPacketSize, R2);
R0 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r7.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
traits.acc(C7, alphav, R0);
traits.acc(C15, alphav, R1);
traits.acc(C23, alphav, R2);
r7.storePacket(0 * Traits::ResPacketSize, R0);
r7.storePacket(1 * Traits::ResPacketSize, R1);
r7.storePacket(2 * Traits::ResPacketSize, R2);
}
}
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
{
for(Index i=i1; i<actual_panel_end; i+=3*LhsProgress)
{
@@ -1473,14 +1809,14 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
r3.prefetch(0);
// performs "inner" products
-const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
prefetch(&blB[0]);
LhsPacket A0, A1;
for(Index k=0; k<peeled_kc; k+=pk)
{
EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX4");
-// 15 registers are taken (12 for acc, 2 for lhs).
+// 15 registers are taken (12 for acc, 3 for lhs).
RhsPanel15 rhs_panel;
RhsPacket T0;
LhsPacket A2;
@@ -1689,7 +2025,170 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
for(Index i1=peeled_mc3; i1<peeled_mc2; i1+=actual_panel_rows)
{
Index actual_panel_end = (std::min)(i1+actual_panel_rows, peeled_mc2);
-for(Index j2=0; j2<packet_cols4; j2+=nr)
+for(Index j2=0; j2<packet_cols8; j2+=8)
{
for(Index i=i1; i<actual_panel_end; i+=2*LhsProgress)
{
const LhsScalar* blA = &blockA[i*strideA+offsetA*(2*Traits::LhsProgress)];
prefetch(&blA[0]);
AccPacket C0, C1, C2, C3, C4, C5, C6, C7,
C8, C9, C10, C11, C12, C13, C14, C15;
traits.initAcc(C0); traits.initAcc(C1); traits.initAcc(C2); traits.initAcc(C3);
traits.initAcc(C4); traits.initAcc(C5); traits.initAcc(C6); traits.initAcc(C7);
traits.initAcc(C8); traits.initAcc(C9); traits.initAcc(C10); traits.initAcc(C11);
traits.initAcc(C12); traits.initAcc(C13); traits.initAcc(C14); traits.initAcc(C15);
LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
r0.prefetch(prefetch_res_offset);
r1.prefetch(prefetch_res_offset);
r2.prefetch(prefetch_res_offset);
r3.prefetch(prefetch_res_offset);
r4.prefetch(prefetch_res_offset);
r5.prefetch(prefetch_res_offset);
r6.prefetch(prefetch_res_offset);
r7.prefetch(prefetch_res_offset);
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
prefetch(&blB[0]);
LhsPacket A0, A1;
for(Index k=0; k<peeled_kc; k+=pk)
{
RhsPacketx4 rhs_panel;
RhsPacket T0;
// NOTE: the begin/end asm comments below work around bug 935!
// but they are not enough for gcc>=6 without FMA (bug 1637)
#if EIGEN_GNUC_AT_LEAST(6,0) && defined(EIGEN_VECTORIZE_SSE)
#define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND __asm__ ("" : [a0] "+x,m" (A0),[a1] "+x,m" (A1));
#else
#define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND
#endif
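// Same idea as the begin/end asm comments: the empty asm marks A0/A1 as
// read-write so gcc>=6 without FMA keeps them in registers instead of
// spilling them across the unrolled steps (bug 1637).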
#define EIGEN_GEBGP_ONESTEP(K) \
do { \
EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX8"); \
traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
EIGEN_GEBP_2Px8_SPILLING_WORKAROUND \
EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX8"); \
} while (false)
EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX8");
EIGEN_GEBGP_ONESTEP(0);
EIGEN_GEBGP_ONESTEP(1);
EIGEN_GEBGP_ONESTEP(2);
EIGEN_GEBGP_ONESTEP(3);
EIGEN_GEBGP_ONESTEP(4);
EIGEN_GEBGP_ONESTEP(5);
EIGEN_GEBGP_ONESTEP(6);
EIGEN_GEBGP_ONESTEP(7);
blB += pk*8*RhsProgress;
blA += pk*(2*Traits::LhsProgress);
EIGEN_ASM_COMMENT("end gebp micro kernel 2pX8");
}
// process remaining peeled loop
for(Index k=peeled_kc; k<depth; k++)
{
RhsPacketx4 rhs_panel;
RhsPacket T0;
EIGEN_GEBGP_ONESTEP(0);
blB += 8*RhsProgress;
blA += 2*Traits::LhsProgress;
}
#undef EIGEN_GEBGP_ONESTEP
ResPacket R0, R1, R2, R3;
ResPacket alphav = pset1<ResPacket>(alpha);
R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R3 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
traits.acc(C0, alphav, R0);
traits.acc(C8, alphav, R1);
traits.acc(C1, alphav, R2);
traits.acc(C9, alphav, R3);
r0.storePacket(0 * Traits::ResPacketSize, R0);
r0.storePacket(1 * Traits::ResPacketSize, R1);
r1.storePacket(0 * Traits::ResPacketSize, R2);
r1.storePacket(1 * Traits::ResPacketSize, R3);
R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R3 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
traits.acc(C2, alphav, R0);
traits.acc(C10, alphav, R1);
traits.acc(C3, alphav, R2);
traits.acc(C11, alphav, R3);
r2.storePacket(0 * Traits::ResPacketSize, R0);
r2.storePacket(1 * Traits::ResPacketSize, R1);
r3.storePacket(0 * Traits::ResPacketSize, R2);
r3.storePacket(1 * Traits::ResPacketSize, R3);
R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R3 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
traits.acc(C4, alphav, R0);
traits.acc(C12, alphav, R1);
traits.acc(C5, alphav, R2);
traits.acc(C13, alphav, R3);
r4.storePacket(0 * Traits::ResPacketSize, R0);
r4.storePacket(1 * Traits::ResPacketSize, R1);
r5.storePacket(0 * Traits::ResPacketSize, R2);
r5.storePacket(1 * Traits::ResPacketSize, R3);
R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
R2 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
R3 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
traits.acc(C6, alphav, R0);
traits.acc(C14, alphav, R1);
traits.acc(C7, alphav, R2);
traits.acc(C15, alphav, R3);
r6.storePacket(0 * Traits::ResPacketSize, R0);
r6.storePacket(1 * Traits::ResPacketSize, R1);
r7.storePacket(0 * Traits::ResPacketSize, R2);
r7.storePacket(1 * Traits::ResPacketSize, R3);
}
}
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
{
for(Index i=i1; i<actual_panel_end; i+=2*LhsProgress)
{
@@ -1717,7 +2216,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
r3.prefetch(prefetch_res_offset);
// performs "inner" products
-const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
prefetch(&blB[0]);
LhsPacket A0, A1;
@@ -1907,14 +2406,66 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
if(peeled_mc_quarter<rows)
{
// loop on each panel of the rhs
-for(Index j2=0; j2<packet_cols4; j2+=nr)
+for(Index j2=0; j2<packet_cols8; j2+=8)
{
// loop on each row of the lhs (1*LhsProgress x depth)
for(Index i=peeled_mc_quarter; i<rows; i+=1)
{
const LhsScalar* blA = &blockA[i*strideA+offsetA];
prefetch(&blA[0]);
-const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
+// gets a 1 x 1 res block as registers
+ResScalar C0(0),C1(0),C2(0),C3(0),C4(0),C5(0),C6(0),C7(0);
+const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
for(Index k=0; k<depth; k++)
{
LhsScalar A0 = blA[k];
RhsScalar B_0;
B_0 = blB[0];
C0 = cj.pmadd(A0, B_0, C0);
B_0 = blB[1];
C1 = cj.pmadd(A0, B_0, C1);
B_0 = blB[2];
C2 = cj.pmadd(A0, B_0, C2);
B_0 = blB[3];
C3 = cj.pmadd(A0, B_0, C3);
B_0 = blB[4];
C4 = cj.pmadd(A0, B_0, C4);
B_0 = blB[5];
C5 = cj.pmadd(A0, B_0, C5);
B_0 = blB[6];
C6 = cj.pmadd(A0, B_0, C6);
B_0 = blB[7];
C7 = cj.pmadd(A0, B_0, C7);
blB += 8;
}
res(i, j2 + 0) += alpha * C0;
res(i, j2 + 1) += alpha * C1;
res(i, j2 + 2) += alpha * C2;
res(i, j2 + 3) += alpha * C3;
res(i, j2 + 4) += alpha * C4;
res(i, j2 + 5) += alpha * C5;
res(i, j2 + 6) += alpha * C6;
res(i, j2 + 7) += alpha * C7;
}
}
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
{
// loop on each row of the lhs (1*LhsProgress x depth)
for(Index i=peeled_mc_quarter; i<rows; i+=1)
{
const LhsScalar* blA = &blockA[i*strideA+offsetA];
prefetch(&blA[0]);
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
// If LhsProgress is 8 or 16, it assumes that there is a
// half or quarter packet, respectively, of the same size as
@@ -2397,51 +2948,121 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, ColMajor, Co
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
Index count = 0;
const Index peeled_k = (depth/PacketSize)*PacketSize;
// if(nr>=8)
// {
// for(Index j2=0; j2<packet_cols8; j2+=8)
// {
// // skip what we have before
// if(PanelMode) count += 8 * offset;
// const Scalar* b0 = &rhs[(j2+0)*rhsStride];
// const Scalar* b1 = &rhs[(j2+1)*rhsStride];
// const Scalar* b2 = &rhs[(j2+2)*rhsStride];
// const Scalar* b3 = &rhs[(j2+3)*rhsStride];
// const Scalar* b4 = &rhs[(j2+4)*rhsStride];
// const Scalar* b5 = &rhs[(j2+5)*rhsStride];
// const Scalar* b6 = &rhs[(j2+6)*rhsStride];
// const Scalar* b7 = &rhs[(j2+7)*rhsStride];
// Index k=0;
// if(PacketSize==8) // TODO enable vectorized transposition for PacketSize==4
// {
// for(; k<peeled_k; k+=PacketSize) {
// PacketBlock<Packet> kernel;
// for (int p = 0; p < PacketSize; ++p) {
// kernel.packet[p] = ploadu<Packet>(&rhs[(j2+p)*rhsStride+k]);
// }
// ptranspose(kernel);
// for (int p = 0; p < PacketSize; ++p) {
// pstoreu(blockB+count, cj.pconj(kernel.packet[p]));
// count+=PacketSize;
// }
// }
// }
// for(; k<depth; k++)
// {
// blockB[count+0] = cj(b0[k]);
// blockB[count+1] = cj(b1[k]);
// blockB[count+2] = cj(b2[k]);
// blockB[count+3] = cj(b3[k]);
// blockB[count+4] = cj(b4[k]);
// blockB[count+5] = cj(b5[k]);
// blockB[count+6] = cj(b6[k]);
// blockB[count+7] = cj(b7[k]);
// count += 8;
// }
// // skip what we have after
// if(PanelMode) count += 8 * (stride-offset-depth);
// }
// }
if(nr>=8)
{
for(Index j2=0; j2<packet_cols8; j2+=8)
{
// skip what we have before
if(PanelMode) count += 8 * offset;
const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
Index k = 0;
if (PacketSize % 2 == 0 && PacketSize <= 8) // 2 4 8
{
for (; k < peeled_k; k += PacketSize)
{
if (PacketSize == 2)
{
PacketBlock<Packet, PacketSize==2 ?2:PacketSize> kernel0, kernel1, kernel2, kernel3;
kernel0.packet[0%PacketSize] = dm0.template loadPacket<Packet>(k);
kernel0.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
kernel1.packet[0%PacketSize] = dm2.template loadPacket<Packet>(k);
kernel1.packet[1%PacketSize] = dm3.template loadPacket<Packet>(k);
kernel2.packet[0%PacketSize] = dm4.template loadPacket<Packet>(k);
kernel2.packet[1%PacketSize] = dm5.template loadPacket<Packet>(k);
kernel3.packet[0%PacketSize] = dm6.template loadPacket<Packet>(k);
kernel3.packet[1%PacketSize] = dm7.template loadPacket<Packet>(k);
ptranspose(kernel0);
ptranspose(kernel1);
ptranspose(kernel2);
ptranspose(kernel3);
pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize]));
pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel1.packet[0 % PacketSize]));
pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel2.packet[0 % PacketSize]));
pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel3.packet[0 % PacketSize]));
pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize]));
pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel1.packet[1 % PacketSize]));
pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel2.packet[1 % PacketSize]));
pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel3.packet[1 % PacketSize]));
count+=8*PacketSize;
}
else if (PacketSize == 4)
{
PacketBlock<Packet, PacketSize == 4?4:PacketSize> kernel0, kernel1;
kernel0.packet[0%PacketSize] = dm0.template loadPacket<Packet>(k);
kernel0.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
kernel0.packet[2%PacketSize] = dm2.template loadPacket<Packet>(k);
kernel0.packet[3%PacketSize] = dm3.template loadPacket<Packet>(k);
kernel1.packet[0%PacketSize] = dm4.template loadPacket<Packet>(k);
kernel1.packet[1%PacketSize] = dm5.template loadPacket<Packet>(k);
kernel1.packet[2%PacketSize] = dm6.template loadPacket<Packet>(k);
kernel1.packet[3%PacketSize] = dm7.template loadPacket<Packet>(k);
ptranspose(kernel0);
ptranspose(kernel1);
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel0.packet[0%PacketSize]));
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel1.packet[0%PacketSize]));
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel0.packet[1%PacketSize]));
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel1.packet[1%PacketSize]));
pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel0.packet[2%PacketSize]));
pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel1.packet[2%PacketSize]));
pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel0.packet[3%PacketSize]));
pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel1.packet[3%PacketSize]));
count+=8*PacketSize;
}
else if (PacketSize == 8)
{
PacketBlock<Packet, PacketSize==8?8:PacketSize> kernel0;
kernel0.packet[0%PacketSize] = dm0.template loadPacket<Packet>(k);
kernel0.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
kernel0.packet[2%PacketSize] = dm2.template loadPacket<Packet>(k);
kernel0.packet[3%PacketSize] = dm3.template loadPacket<Packet>(k);
kernel0.packet[4%PacketSize] = dm4.template loadPacket<Packet>(k);
kernel0.packet[5%PacketSize] = dm5.template loadPacket<Packet>(k);
kernel0.packet[6%PacketSize] = dm6.template loadPacket<Packet>(k);
kernel0.packet[7%PacketSize] = dm7.template loadPacket<Packet>(k);
ptranspose(kernel0);
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel0.packet[0%PacketSize]));
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel0.packet[1%PacketSize]));
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel0.packet[2%PacketSize]));
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel0.packet[3%PacketSize]));
pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel0.packet[4%PacketSize]));
pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel0.packet[5%PacketSize]));
pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel0.packet[6%PacketSize]));
pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel0.packet[7%PacketSize]));
count+=8*PacketSize;
}
}
}
for(; k<depth; k++)
{
blockB[count+0] = cj(dm0(k));
blockB[count+1] = cj(dm1(k));
blockB[count+2] = cj(dm2(k));
blockB[count+3] = cj(dm3(k));
blockB[count+4] = cj(dm4(k));
blockB[count+5] = cj(dm5(k));
blockB[count+6] = cj(dm6(k));
blockB[count+7] = cj(dm7(k));
count += 8;
}
// skip what we have after
if(PanelMode) count += 8 * (stride-offset-depth);
}
}
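// Net layout of the nr>=8 path: for each k, blockB holds the eight rhs
// values b(k, j2+0..7) contiguously, which is exactly the stride-8
// pattern (offsetB*8, blB[(c + 8*K)*RhsProgress]) the x8 kernels read.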
if(nr>=4)
{
@@ -2522,39 +3143,41 @@ struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMo
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
Index count = 0;
// if(nr>=8)
// {
// for(Index j2=0; j2<packet_cols8; j2+=8)
// {
// // skip what we have before
// if(PanelMode) count += 8 * offset;
// for(Index k=0; k<depth; k++)
// {
// if (PacketSize==8) {
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
// pstoreu(blockB+count, cj.pconj(A));
// } else if (PacketSize==4) {
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
// Packet B = ploadu<Packet>(&rhs[k*rhsStride + j2 + PacketSize]);
// pstoreu(blockB+count, cj.pconj(A));
// pstoreu(blockB+count+PacketSize, cj.pconj(B));
// } else {
// const Scalar* b0 = &rhs[k*rhsStride + j2];
// blockB[count+0] = cj(b0[0]);
// blockB[count+1] = cj(b0[1]);
// blockB[count+2] = cj(b0[2]);
// blockB[count+3] = cj(b0[3]);
// blockB[count+4] = cj(b0[4]);
// blockB[count+5] = cj(b0[5]);
// blockB[count+6] = cj(b0[6]);
// blockB[count+7] = cj(b0[7]);
// }
// count += 8;
// }
// // skip what we have after
// if(PanelMode) count += 8 * (stride-offset-depth);
// }
// }
if(nr>=8)
{
for(Index j2=0; j2<packet_cols8; j2+=8)
{
// skip what we have before
if(PanelMode) count += 8 * offset;
for(Index k=0; k<depth; k++)
{
if (PacketSize==8) {
Packet A = rhs.template loadPacket<Packet>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += PacketSize;
} else if (PacketSize==4) {
Packet A = rhs.template loadPacket<Packet>(k, j2);
Packet B = rhs.template loadPacket<Packet>(k, j2 + 4);
pstoreu(blockB+count, cj.pconj(A));
pstoreu(blockB+count+PacketSize, cj.pconj(B));
count += 2*PacketSize;
} else {
const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
blockB[count+0] = cj(dm0(0));
blockB[count+1] = cj(dm0(1));
blockB[count+2] = cj(dm0(2));
blockB[count+3] = cj(dm0(3));
blockB[count+4] = cj(dm0(4));
blockB[count+5] = cj(dm0(5));
blockB[count+6] = cj(dm0(6));
blockB[count+7] = cj(dm0(7));
count += 8;
}
}
// skip what we have after
if(PanelMode) count += 8 * (stride-offset-depth);
}
}
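// In the row-major case a depth-slice of eight columns is already
// contiguous in rhs, so packing is a direct (possibly conjugated) copy:
// one packet for PacketSize==8, two for PacketSize==4, scalars otherwise.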
if(nr>=4)
{
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)