mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-10 18:59:01 +08:00
Use 3px8/2px8/1px8/1x8 gebp_kernel on arm64-neon
This commit is contained in:
parent
7b2901e2aa
commit
23299632c2
@ -49,7 +49,9 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
|
||||
{
|
||||
typedef float RhsPacket;
|
||||
typedef float32x4_t RhsPacketx4;
|
||||
|
||||
enum {
|
||||
nr = 8
|
||||
};
|
||||
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
|
||||
{
|
||||
dest = *b;
|
||||
@ -77,7 +79,6 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
|
||||
{
|
||||
c = vfmaq_n_f32(c, a, b);
|
||||
}
|
||||
|
||||
// NOTE: Template parameter inference failed when compiled with Android NDK:
|
||||
// "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
|
||||
|
||||
@ -94,9 +95,10 @@ struct gebp_traits <float,float,false,false,Architecture::NEON,GEBPPacketFull>
|
||||
template<int LaneID>
|
||||
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
|
||||
{
|
||||
#if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
|
||||
// workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
|
||||
// vfmaq_laneq_f32 is implemented through a costly dup
|
||||
#if EIGEN_COMP_GNUC_STRICT
|
||||
// 1. workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
|
||||
// vfmaq_laneq_f32 is implemented through a costly dup, which was fixed in gcc9
|
||||
// 2. workaround the gcc register split problem on arm64-neon
|
||||
if(LaneID==0) asm("fmla %0.4s, %1.4s, %2.s[0]\n" : "+w" (c) : "w" (a), "w" (b) : );
|
||||
else if(LaneID==1) asm("fmla %0.4s, %1.4s, %2.s[1]\n" : "+w" (c) : "w" (a), "w" (b) : );
|
||||
else if(LaneID==2) asm("fmla %0.4s, %1.4s, %2.s[2]\n" : "+w" (c) : "w" (a), "w" (b) : );
|
||||
@ -113,7 +115,9 @@ struct gebp_traits <double,double,false,false,Architecture::NEON>
|
||||
: gebp_traits<double,double,false,false,Architecture::Generic>
|
||||
{
|
||||
typedef double RhsPacket;
|
||||
|
||||
enum {
|
||||
nr = 8
|
||||
};
|
||||
struct RhsPacketx4 {
|
||||
float64x2_t B_0, B_1;
|
||||
};
|
||||
@ -163,9 +167,10 @@ struct gebp_traits <double,double,false,false,Architecture::NEON>
|
||||
template <int LaneID>
|
||||
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
|
||||
{
|
||||
#if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0))
|
||||
// workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
|
||||
// vfmaq_laneq_f64 is implemented through a costly dup
|
||||
#if EIGEN_COMP_GNUC_STRICT
|
||||
// 1. workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
|
||||
// vfmaq_laneq_f64 is implemented through a costly dup, which was fixed in gcc9
|
||||
// 2. workaround the gcc register split problem on arm64-neon
|
||||
if(LaneID==0) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
|
||||
else if(LaneID==1) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : );
|
||||
else if(LaneID==2) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : );
|
||||
@ -179,6 +184,77 @@ struct gebp_traits <double,double,false,false,Architecture::NEON>
|
||||
}
|
||||
};
|
||||
|
||||
#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template<>
|
||||
struct gebp_traits <half,half,false,false,Architecture::NEON>
|
||||
: gebp_traits<half,half,false,false,Architecture::Generic>
|
||||
{
|
||||
typedef half RhsPacket;
|
||||
typedef float16x4_t RhsPacketx4;
|
||||
typedef float16x4_t PacketHalf;
|
||||
enum {
|
||||
nr = 8
|
||||
};
|
||||
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
|
||||
{
|
||||
dest = *b;
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const
|
||||
{
|
||||
dest = vld1_f16((const __fp16 *)b);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const
|
||||
{
|
||||
dest = *b;
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const
|
||||
{}
|
||||
|
||||
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
|
||||
{
|
||||
loadRhs(b,dest);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
|
||||
{
|
||||
c = vfmaq_n_f16(c, a, b);
|
||||
}
|
||||
EIGEN_STRONG_INLINE void madd(const PacketHalf& a, const RhsPacket& b, PacketHalf& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
|
||||
{
|
||||
c = vfma_n_f16(c, a, b);
|
||||
}
|
||||
|
||||
// NOTE: Template parameter inference failed when compiled with Android NDK:
|
||||
// "candidate template ignored: could not match 'FixedInt<N>' against 'Eigen::internal::FixedInt<0>".
|
||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
|
||||
{ madd_helper<0>(a, b, c); }
|
||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const
|
||||
{ madd_helper<1>(a, b, c); }
|
||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const
|
||||
{ madd_helper<2>(a, b, c); }
|
||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const
|
||||
{ madd_helper<3>(a, b, c); }
|
||||
private:
|
||||
template<int LaneID>
|
||||
EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const
|
||||
{
|
||||
#if EIGEN_COMP_GNUC_STRICT
|
||||
// 1. vfmaq_lane_f16 is implemented through a costly dup
|
||||
// 2. workaround the gcc register split problem on arm64-neon
|
||||
if(LaneID==0) asm("fmla %0.8h, %1.8h, %2.h[0]\n" : "+w" (c) : "w" (a), "w" (b) : );
|
||||
else if(LaneID==1) asm("fmla %0.8h, %1.8h, %2.h[1]\n" : "+w" (c) : "w" (a), "w" (b) : );
|
||||
else if(LaneID==2) asm("fmla %0.8h, %1.8h, %2.h[2]\n" : "+w" (c) : "w" (a), "w" (b) : );
|
||||
else if(LaneID==3) asm("fmla %0.8h, %1.8h, %2.h[3]\n" : "+w" (c) : "w" (a), "w" (b) : );
|
||||
#else
|
||||
c = vfmaq_lane_f16(c, a, b, LaneID);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
#endif // EIGEN_ARCH_ARM64
|
||||
|
||||
} // namespace internal
|
||||
|
@ -1070,6 +1070,7 @@ struct gebp_kernel
|
||||
typedef typename Traits::RhsPacketx4 RhsPacketx4;
|
||||
|
||||
typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 15>::type RhsPanel15;
|
||||
typedef typename RhsPanelHelper<RhsPacket, RhsPacketx4, 27>::type RhsPanel27;
|
||||
|
||||
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs,Architecture::Target> SwappedTraits;
|
||||
|
||||
@ -1215,13 +1216,135 @@ struct lhs_process_one_packet
|
||||
int prefetch_res_offset, Index peeled_kc, Index pk, Index cols, Index depth, Index packet_cols4)
|
||||
{
|
||||
GEBPTraits traits;
|
||||
|
||||
Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
|
||||
// loops on each largest micro horizontal panel of lhs
|
||||
// (LhsProgress x depth)
|
||||
for(Index i=peelStart; i<peelEnd; i+=LhsProgress)
|
||||
{
|
||||
for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
{
|
||||
const LhsScalar* blA = &blockA[i*strideA+offsetA*(LhsProgress)];
|
||||
prefetch(&blA[0]);
|
||||
|
||||
// gets res block as register
|
||||
AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
|
||||
traits.initAcc(C0);
|
||||
traits.initAcc(C1);
|
||||
traits.initAcc(C2);
|
||||
traits.initAcc(C3);
|
||||
traits.initAcc(C4);
|
||||
traits.initAcc(C5);
|
||||
traits.initAcc(C6);
|
||||
traits.initAcc(C7);
|
||||
|
||||
LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
|
||||
LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
|
||||
LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
|
||||
LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
|
||||
LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
|
||||
LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
|
||||
LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
|
||||
LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
|
||||
r0.prefetch(prefetch_res_offset);
|
||||
r1.prefetch(prefetch_res_offset);
|
||||
r2.prefetch(prefetch_res_offset);
|
||||
r3.prefetch(prefetch_res_offset);
|
||||
r4.prefetch(prefetch_res_offset);
|
||||
r5.prefetch(prefetch_res_offset);
|
||||
r6.prefetch(prefetch_res_offset);
|
||||
r7.prefetch(prefetch_res_offset);
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
|
||||
prefetch(&blB[0]);
|
||||
|
||||
LhsPacket A0;
|
||||
for(Index k=0; k<peeled_kc; k+=pk)
|
||||
{
|
||||
RhsPacketx4 rhs_panel;
|
||||
RhsPacket T0;
|
||||
#define EIGEN_GEBGP_ONESTEP(K) \
|
||||
do { \
|
||||
EIGEN_ASM_COMMENT("begin step of gebp micro kernel 1pX8"); \
|
||||
traits.loadLhs(&blA[(0 + 1 * K) * LhsProgress], A0); \
|
||||
traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
|
||||
traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
|
||||
traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
|
||||
traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
|
||||
traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
|
||||
traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
|
||||
traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
|
||||
traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
|
||||
EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX8"); \
|
||||
} while (false)
|
||||
|
||||
EIGEN_ASM_COMMENT("begin gebp micro kernel 1pX8");
|
||||
|
||||
EIGEN_GEBGP_ONESTEP(0);
|
||||
EIGEN_GEBGP_ONESTEP(1);
|
||||
EIGEN_GEBGP_ONESTEP(2);
|
||||
EIGEN_GEBGP_ONESTEP(3);
|
||||
EIGEN_GEBGP_ONESTEP(4);
|
||||
EIGEN_GEBGP_ONESTEP(5);
|
||||
EIGEN_GEBGP_ONESTEP(6);
|
||||
EIGEN_GEBGP_ONESTEP(7);
|
||||
|
||||
blB += pk*8*RhsProgress;
|
||||
blA += pk*(1*LhsProgress);
|
||||
|
||||
EIGEN_ASM_COMMENT("end gebp micro kernel 1pX8");
|
||||
}
|
||||
// process remaining peeled loop
|
||||
for(Index k=peeled_kc; k<depth; k++)
|
||||
{
|
||||
RhsPacketx4 rhs_panel;
|
||||
RhsPacket T0;
|
||||
EIGEN_GEBGP_ONESTEP(0);
|
||||
blB += 8*RhsProgress;
|
||||
blA += 1*LhsProgress;
|
||||
}
|
||||
|
||||
#undef EIGEN_GEBGP_ONESTEP
|
||||
|
||||
ResPacket R0, R1;
|
||||
ResPacket alphav = pset1<ResPacket>(alpha);
|
||||
|
||||
R0 = r0.template loadPacket<ResPacket>(0);
|
||||
R1 = r1.template loadPacket<ResPacket>(0);
|
||||
traits.acc(C0, alphav, R0);
|
||||
traits.acc(C1, alphav, R1);
|
||||
r0.storePacket(0, R0);
|
||||
r1.storePacket(0, R1);
|
||||
|
||||
R0 = r2.template loadPacket<ResPacket>(0);
|
||||
R1 = r3.template loadPacket<ResPacket>(0);
|
||||
traits.acc(C2, alphav, R0);
|
||||
traits.acc(C3, alphav, R1);
|
||||
r2.storePacket(0, R0);
|
||||
r3.storePacket(0, R1);
|
||||
|
||||
R0 = r4.template loadPacket<ResPacket>(0);
|
||||
R1 = r5.template loadPacket<ResPacket>(0);
|
||||
traits.acc(C4, alphav, R0);
|
||||
traits.acc(C5, alphav, R1);
|
||||
r4.storePacket(0, R0);
|
||||
r5.storePacket(0, R1);
|
||||
|
||||
R0 = r6.template loadPacket<ResPacket>(0);
|
||||
R1 = r7.template loadPacket<ResPacket>(0);
|
||||
traits.acc(C6, alphav, R0);
|
||||
traits.acc(C7, alphav, R1);
|
||||
r6.storePacket(0, R0);
|
||||
r7.storePacket(0, R1);
|
||||
}
|
||||
// loops on each largest micro vertical panel of rhs (depth * nr)
|
||||
for(Index j2=0; j2<packet_cols4; j2+=nr)
|
||||
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
|
||||
{
|
||||
// We select a LhsProgress x nr micro block of res
|
||||
// which is entirely stored into 1 x nr registers.
|
||||
@ -1257,7 +1380,7 @@ struct lhs_process_one_packet
|
||||
r3.prefetch(prefetch_res_offset);
|
||||
|
||||
// performs "inner" products
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
|
||||
prefetch(&blB[0]);
|
||||
LhsPacket A0, A1;
|
||||
|
||||
@ -1415,6 +1538,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
||||
if(strideB==-1) strideB = depth;
|
||||
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
|
||||
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
|
||||
Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
|
||||
const Index peeled_mc3 = mr>=3*Traits::LhsProgress ? (rows/(3*LhsProgress))*(3*LhsProgress) : 0;
|
||||
const Index peeled_mc2 = mr>=2*Traits::LhsProgress ? peeled_mc3+((rows-peeled_mc3)/(2*LhsProgress))*(2*LhsProgress) : 0;
|
||||
const Index peeled_mc1 = mr>=1*Traits::LhsProgress ? peeled_mc2+((rows-peeled_mc2)/(1*LhsProgress))*(1*LhsProgress) : 0;
|
||||
@ -1443,7 +1567,219 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
||||
for(Index i1=0; i1<peeled_mc3; i1+=actual_panel_rows)
|
||||
{
|
||||
const Index actual_panel_end = (std::min)(i1+actual_panel_rows, peeled_mc3);
|
||||
for(Index j2=0; j2<packet_cols4; j2+=nr)
|
||||
|
||||
// nr >= 8
|
||||
for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
{
|
||||
for(Index i=i1; i<actual_panel_end; i+=3*LhsProgress)
|
||||
{
|
||||
const LhsScalar* blA = &blockA[i*strideA+offsetA*(3*LhsProgress)];
|
||||
prefetch(&blA[0]);
|
||||
// gets res block as register
|
||||
AccPacket C0, C1, C2, C3, C4, C5, C6, C7,
|
||||
C8, C9, C10, C11, C12, C13, C14, C15,
|
||||
C16, C17, C18, C19, C20, C21, C22, C23;
|
||||
traits.initAcc(C0); traits.initAcc(C1); traits.initAcc(C2); traits.initAcc(C3);
|
||||
traits.initAcc(C4); traits.initAcc(C5); traits.initAcc(C6); traits.initAcc(C7);
|
||||
traits.initAcc(C8); traits.initAcc(C9); traits.initAcc(C10); traits.initAcc(C11);
|
||||
traits.initAcc(C12); traits.initAcc(C13); traits.initAcc(C14); traits.initAcc(C15);
|
||||
traits.initAcc(C16); traits.initAcc(C17); traits.initAcc(C18); traits.initAcc(C19);
|
||||
traits.initAcc(C20); traits.initAcc(C21); traits.initAcc(C22); traits.initAcc(C23);
|
||||
|
||||
LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
|
||||
LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
|
||||
LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
|
||||
LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
|
||||
LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
|
||||
LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
|
||||
LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
|
||||
LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
|
||||
|
||||
r0.prefetch(0);
|
||||
r1.prefetch(0);
|
||||
r2.prefetch(0);
|
||||
r3.prefetch(0);
|
||||
r4.prefetch(0);
|
||||
r5.prefetch(0);
|
||||
r6.prefetch(0);
|
||||
r7.prefetch(0);
|
||||
|
||||
// performs "inner" products
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
|
||||
prefetch(&blB[0]);
|
||||
LhsPacket A0, A1;
|
||||
for(Index k=0; k<peeled_kc; k+=pk)
|
||||
{
|
||||
EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX8");
|
||||
// 27 registers are taken (24 for acc, 3 for lhs).
|
||||
RhsPanel27 rhs_panel;
|
||||
RhsPacket T0;
|
||||
LhsPacket A2;
|
||||
#if EIGEN_COMP_GNUC_STRICT && EIGEN_ARCH_ARM64 && defined(EIGEN_VECTORIZE_NEON) && !(EIGEN_GNUC_AT_LEAST(9,0))
|
||||
// see http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1633
|
||||
// without this workaround A0, A1, and A2 are loaded in the same register,
|
||||
// which is not good for pipelining
|
||||
#define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND __asm__ ("" : "+w,m" (A0), "+w,m" (A1), "+w,m" (A2));
|
||||
#else
|
||||
#define EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND
|
||||
#endif
|
||||
|
||||
#define EIGEN_GEBP_ONESTEP(K) \
|
||||
do { \
|
||||
EIGEN_ASM_COMMENT("begin step of gebp micro kernel 3pX8"); \
|
||||
traits.loadLhs(&blA[(0 + 3 * K) * LhsProgress], A0); \
|
||||
traits.loadLhs(&blA[(1 + 3 * K) * LhsProgress], A1); \
|
||||
traits.loadLhs(&blA[(2 + 3 * K) * LhsProgress], A2); \
|
||||
EIGEN_GEBP_3Px8_REGISTER_ALLOC_WORKAROUND \
|
||||
traits.loadRhs(blB + (0 + 8 * K) * Traits::RhsProgress, rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
|
||||
traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
|
||||
traits.madd(A2, rhs_panel, C16, T0, fix<0>); \
|
||||
traits.updateRhs(blB + (1 + 8 * K) * Traits::RhsProgress, rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
|
||||
traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
|
||||
traits.madd(A2, rhs_panel, C17, T0, fix<1>); \
|
||||
traits.updateRhs(blB + (2 + 8 * K) * Traits::RhsProgress, rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
|
||||
traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
|
||||
traits.madd(A2, rhs_panel, C18, T0, fix<2>); \
|
||||
traits.updateRhs(blB + (3 + 8 * K) * Traits::RhsProgress, rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
|
||||
traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
|
||||
traits.madd(A2, rhs_panel, C19, T0, fix<3>); \
|
||||
traits.loadRhs(blB + (4 + 8 * K) * Traits::RhsProgress, rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
|
||||
traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
|
||||
traits.madd(A2, rhs_panel, C20, T0, fix<0>); \
|
||||
traits.updateRhs(blB + (5 + 8 * K) * Traits::RhsProgress, rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
|
||||
traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
|
||||
traits.madd(A2, rhs_panel, C21, T0, fix<1>); \
|
||||
traits.updateRhs(blB + (6 + 8 * K) * Traits::RhsProgress, rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
|
||||
traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
|
||||
traits.madd(A2, rhs_panel, C22, T0, fix<2>); \
|
||||
traits.updateRhs(blB + (7 + 8 * K) * Traits::RhsProgress, rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
|
||||
traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
|
||||
traits.madd(A2, rhs_panel, C23, T0, fix<3>); \
|
||||
EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX8"); \
|
||||
} while (false)
|
||||
|
||||
EIGEN_GEBP_ONESTEP(0);
|
||||
EIGEN_GEBP_ONESTEP(1);
|
||||
EIGEN_GEBP_ONESTEP(2);
|
||||
EIGEN_GEBP_ONESTEP(3);
|
||||
EIGEN_GEBP_ONESTEP(4);
|
||||
EIGEN_GEBP_ONESTEP(5);
|
||||
EIGEN_GEBP_ONESTEP(6);
|
||||
EIGEN_GEBP_ONESTEP(7);
|
||||
|
||||
blB += pk * 8 * RhsProgress;
|
||||
blA += pk * 3 * Traits::LhsProgress;
|
||||
EIGEN_ASM_COMMENT("end gebp micro kernel 3pX8");
|
||||
}
|
||||
|
||||
// process remaining peeled loop
|
||||
for (Index k = peeled_kc; k < depth; k++)
|
||||
{
|
||||
|
||||
RhsPanel27 rhs_panel;
|
||||
RhsPacket T0;
|
||||
LhsPacket A2;
|
||||
EIGEN_GEBP_ONESTEP(0);
|
||||
blB += 8 * RhsProgress;
|
||||
blA += 3 * Traits::LhsProgress;
|
||||
}
|
||||
|
||||
#undef EIGEN_GEBP_ONESTEP
|
||||
|
||||
ResPacket R0, R1, R2;
|
||||
ResPacket alphav = pset1<ResPacket>(alpha);
|
||||
|
||||
R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r0.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
|
||||
traits.acc(C0, alphav, R0);
|
||||
traits.acc(C8, alphav, R1);
|
||||
traits.acc(C16, alphav, R2);
|
||||
r0.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r0.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r0.storePacket(2 * Traits::ResPacketSize, R2);
|
||||
|
||||
R0 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r1.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
|
||||
traits.acc(C1, alphav, R0);
|
||||
traits.acc(C9, alphav, R1);
|
||||
traits.acc(C17, alphav, R2);
|
||||
r1.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r1.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r1.storePacket(2 * Traits::ResPacketSize, R2);
|
||||
|
||||
R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r2.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
|
||||
traits.acc(C2, alphav, R0);
|
||||
traits.acc(C10, alphav, R1);
|
||||
traits.acc(C18, alphav, R2);
|
||||
r2.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r2.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r2.storePacket(2 * Traits::ResPacketSize, R2);
|
||||
|
||||
R0 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r3.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
|
||||
traits.acc(C3, alphav, R0);
|
||||
traits.acc(C11, alphav, R1);
|
||||
traits.acc(C19, alphav, R2);
|
||||
r3.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r3.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r3.storePacket(2 * Traits::ResPacketSize, R2);
|
||||
|
||||
R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r4.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
|
||||
traits.acc(C4, alphav, R0);
|
||||
traits.acc(C12, alphav, R1);
|
||||
traits.acc(C20, alphav, R2);
|
||||
r4.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r4.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r4.storePacket(2 * Traits::ResPacketSize, R2);
|
||||
|
||||
R0 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r5.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
|
||||
traits.acc(C5, alphav, R0);
|
||||
traits.acc(C13, alphav, R1);
|
||||
traits.acc(C21, alphav, R2);
|
||||
r5.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r5.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r5.storePacket(2 * Traits::ResPacketSize, R2);
|
||||
|
||||
R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r6.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
|
||||
traits.acc(C6, alphav, R0);
|
||||
traits.acc(C14, alphav, R1);
|
||||
traits.acc(C22, alphav, R2);
|
||||
r6.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r6.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r6.storePacket(2 * Traits::ResPacketSize, R2);
|
||||
|
||||
R0 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r7.template loadPacket<ResPacket>(2 * Traits::ResPacketSize);
|
||||
traits.acc(C7, alphav, R0);
|
||||
traits.acc(C15, alphav, R1);
|
||||
traits.acc(C23, alphav, R2);
|
||||
r7.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r7.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r7.storePacket(2 * Traits::ResPacketSize, R2);
|
||||
}
|
||||
}
|
||||
|
||||
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
|
||||
{
|
||||
for(Index i=i1; i<actual_panel_end; i+=3*LhsProgress)
|
||||
{
|
||||
@ -1473,14 +1809,14 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
||||
r3.prefetch(0);
|
||||
|
||||
// performs "inner" products
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
|
||||
prefetch(&blB[0]);
|
||||
LhsPacket A0, A1;
|
||||
|
||||
for(Index k=0; k<peeled_kc; k+=pk)
|
||||
{
|
||||
EIGEN_ASM_COMMENT("begin gebp micro kernel 3pX4");
|
||||
// 15 registers are taken (12 for acc, 2 for lhs).
|
||||
// 15 registers are taken (12 for acc, 3 for lhs).
|
||||
RhsPanel15 rhs_panel;
|
||||
RhsPacket T0;
|
||||
LhsPacket A2;
|
||||
@ -1689,7 +2025,170 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
||||
for(Index i1=peeled_mc3; i1<peeled_mc2; i1+=actual_panel_rows)
|
||||
{
|
||||
Index actual_panel_end = (std::min)(i1+actual_panel_rows, peeled_mc2);
|
||||
for(Index j2=0; j2<packet_cols4; j2+=nr)
|
||||
for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
{
|
||||
for(Index i=i1; i<actual_panel_end; i+=2*LhsProgress)
|
||||
{
|
||||
const LhsScalar* blA = &blockA[i*strideA+offsetA*(2*Traits::LhsProgress)];
|
||||
prefetch(&blA[0]);
|
||||
|
||||
AccPacket C0, C1, C2, C3, C4, C5, C6, C7,
|
||||
C8, C9, C10, C11, C12, C13, C14, C15;
|
||||
traits.initAcc(C0); traits.initAcc(C1); traits.initAcc(C2); traits.initAcc(C3);
|
||||
traits.initAcc(C4); traits.initAcc(C5); traits.initAcc(C6); traits.initAcc(C7);
|
||||
traits.initAcc(C8); traits.initAcc(C9); traits.initAcc(C10); traits.initAcc(C11);
|
||||
traits.initAcc(C12); traits.initAcc(C13); traits.initAcc(C14); traits.initAcc(C15);
|
||||
|
||||
LinearMapper r0 = res.getLinearMapper(i, j2 + 0);
|
||||
LinearMapper r1 = res.getLinearMapper(i, j2 + 1);
|
||||
LinearMapper r2 = res.getLinearMapper(i, j2 + 2);
|
||||
LinearMapper r3 = res.getLinearMapper(i, j2 + 3);
|
||||
LinearMapper r4 = res.getLinearMapper(i, j2 + 4);
|
||||
LinearMapper r5 = res.getLinearMapper(i, j2 + 5);
|
||||
LinearMapper r6 = res.getLinearMapper(i, j2 + 6);
|
||||
LinearMapper r7 = res.getLinearMapper(i, j2 + 7);
|
||||
r0.prefetch(prefetch_res_offset);
|
||||
r1.prefetch(prefetch_res_offset);
|
||||
r2.prefetch(prefetch_res_offset);
|
||||
r3.prefetch(prefetch_res_offset);
|
||||
r4.prefetch(prefetch_res_offset);
|
||||
r5.prefetch(prefetch_res_offset);
|
||||
r6.prefetch(prefetch_res_offset);
|
||||
r7.prefetch(prefetch_res_offset);
|
||||
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
|
||||
prefetch(&blB[0]);
|
||||
LhsPacket A0, A1;
|
||||
for(Index k=0; k<peeled_kc; k+=pk)
|
||||
{
|
||||
RhsPacketx4 rhs_panel;
|
||||
RhsPacket T0;
|
||||
// NOTE: the begin/end asm comments below work around bug 935!
|
||||
// but they are not enough for gcc>=6 without FMA (bug 1637)
|
||||
#if EIGEN_GNUC_AT_LEAST(6,0) && defined(EIGEN_VECTORIZE_SSE)
|
||||
#define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND __asm__ ("" : [a0] "+x,m" (A0),[a1] "+x,m" (A1));
|
||||
#else
|
||||
#define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND
|
||||
#endif
|
||||
#define EIGEN_GEBGP_ONESTEP(K) \
|
||||
do { \
|
||||
EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX8"); \
|
||||
traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \
|
||||
traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \
|
||||
traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C0, T0, fix<0>); \
|
||||
traits.madd(A1, rhs_panel, C8, T0, fix<0>); \
|
||||
traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C1, T0, fix<1>); \
|
||||
traits.madd(A1, rhs_panel, C9, T0, fix<1>); \
|
||||
traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C2, T0, fix<2>); \
|
||||
traits.madd(A1, rhs_panel, C10, T0, fix<2>); \
|
||||
traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C3, T0, fix<3>); \
|
||||
traits.madd(A1, rhs_panel, C11, T0, fix<3>); \
|
||||
traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C4, T0, fix<0>); \
|
||||
traits.madd(A1, rhs_panel, C12, T0, fix<0>); \
|
||||
traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C5, T0, fix<1>); \
|
||||
traits.madd(A1, rhs_panel, C13, T0, fix<1>); \
|
||||
traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C6, T0, fix<2>); \
|
||||
traits.madd(A1, rhs_panel, C14, T0, fix<2>); \
|
||||
traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \
|
||||
traits.madd(A0, rhs_panel, C7, T0, fix<3>); \
|
||||
traits.madd(A1, rhs_panel, C15, T0, fix<3>); \
|
||||
EIGEN_GEBP_2Px8_SPILLING_WORKAROUND \
|
||||
EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX8"); \
|
||||
} while (false)
|
||||
|
||||
EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX8");
|
||||
|
||||
EIGEN_GEBGP_ONESTEP(0);
|
||||
EIGEN_GEBGP_ONESTEP(1);
|
||||
EIGEN_GEBGP_ONESTEP(2);
|
||||
EIGEN_GEBGP_ONESTEP(3);
|
||||
EIGEN_GEBGP_ONESTEP(4);
|
||||
EIGEN_GEBGP_ONESTEP(5);
|
||||
EIGEN_GEBGP_ONESTEP(6);
|
||||
EIGEN_GEBGP_ONESTEP(7);
|
||||
|
||||
blB += pk*8*RhsProgress;
|
||||
blA += pk*(2*Traits::LhsProgress);
|
||||
|
||||
EIGEN_ASM_COMMENT("end gebp micro kernel 2pX8");
|
||||
}
|
||||
// process remaining peeled loop
|
||||
for(Index k=peeled_kc; k<depth; k++)
|
||||
{
|
||||
RhsPacketx4 rhs_panel;
|
||||
RhsPacket T0;
|
||||
EIGEN_GEBGP_ONESTEP(0);
|
||||
blB += 8*RhsProgress;
|
||||
blA += 2*Traits::LhsProgress;
|
||||
}
|
||||
|
||||
#undef EIGEN_GEBGP_ONESTEP
|
||||
|
||||
ResPacket R0, R1, R2, R3;
|
||||
ResPacket alphav = pset1<ResPacket>(alpha);
|
||||
|
||||
R0 = r0.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r0.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r1.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R3 = r1.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
traits.acc(C0, alphav, R0);
|
||||
traits.acc(C8, alphav, R1);
|
||||
traits.acc(C1, alphav, R2);
|
||||
traits.acc(C9, alphav, R3);
|
||||
r0.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r0.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r1.storePacket(0 * Traits::ResPacketSize, R2);
|
||||
r1.storePacket(1 * Traits::ResPacketSize, R3);
|
||||
|
||||
R0 = r2.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r2.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r3.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R3 = r3.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
traits.acc(C2, alphav, R0);
|
||||
traits.acc(C10, alphav, R1);
|
||||
traits.acc(C3, alphav, R2);
|
||||
traits.acc(C11, alphav, R3);
|
||||
r2.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r2.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r3.storePacket(0 * Traits::ResPacketSize, R2);
|
||||
r3.storePacket(1 * Traits::ResPacketSize, R3);
|
||||
|
||||
R0 = r4.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r4.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r5.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R3 = r5.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
traits.acc(C4, alphav, R0);
|
||||
traits.acc(C12, alphav, R1);
|
||||
traits.acc(C5, alphav, R2);
|
||||
traits.acc(C13, alphav, R3);
|
||||
r4.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r4.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r5.storePacket(0 * Traits::ResPacketSize, R2);
|
||||
r5.storePacket(1 * Traits::ResPacketSize, R3);
|
||||
|
||||
R0 = r6.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R1 = r6.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
R2 = r7.template loadPacket<ResPacket>(0 * Traits::ResPacketSize);
|
||||
R3 = r7.template loadPacket<ResPacket>(1 * Traits::ResPacketSize);
|
||||
traits.acc(C6, alphav, R0);
|
||||
traits.acc(C14, alphav, R1);
|
||||
traits.acc(C7, alphav, R2);
|
||||
traits.acc(C15, alphav, R3);
|
||||
r6.storePacket(0 * Traits::ResPacketSize, R0);
|
||||
r6.storePacket(1 * Traits::ResPacketSize, R1);
|
||||
r7.storePacket(0 * Traits::ResPacketSize, R2);
|
||||
r7.storePacket(1 * Traits::ResPacketSize, R3);
|
||||
}
|
||||
}
|
||||
|
||||
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
|
||||
{
|
||||
for(Index i=i1; i<actual_panel_end; i+=2*LhsProgress)
|
||||
{
|
||||
@ -1717,7 +2216,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
||||
r3.prefetch(prefetch_res_offset);
|
||||
|
||||
// performs "inner" products
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
|
||||
prefetch(&blB[0]);
|
||||
LhsPacket A0, A1;
|
||||
|
||||
@ -1907,14 +2406,66 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
||||
if(peeled_mc_quarter<rows)
|
||||
{
|
||||
// loop on each panel of the rhs
|
||||
for(Index j2=0; j2<packet_cols4; j2+=nr)
|
||||
for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
{
|
||||
// loop on each row of the lhs (1*LhsProgress x depth)
|
||||
for(Index i=peeled_mc_quarter; i<rows; i+=1)
|
||||
{
|
||||
const LhsScalar* blA = &blockA[i*strideA+offsetA];
|
||||
prefetch(&blA[0]);
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
|
||||
// gets a 1 x 1 res block as registers
|
||||
ResScalar C0(0),C1(0),C2(0),C3(0),C4(0),C5(0),C6(0),C7(0);
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
LhsScalar A0 = blA[k];
|
||||
RhsScalar B_0;
|
||||
|
||||
B_0 = blB[0];
|
||||
C0 = cj.pmadd(A0, B_0, C0);
|
||||
|
||||
B_0 = blB[1];
|
||||
C1 = cj.pmadd(A0, B_0, C1);
|
||||
|
||||
B_0 = blB[2];
|
||||
C2 = cj.pmadd(A0, B_0, C2);
|
||||
|
||||
B_0 = blB[3];
|
||||
C3 = cj.pmadd(A0, B_0, C3);
|
||||
|
||||
B_0 = blB[4];
|
||||
C4 = cj.pmadd(A0, B_0, C4);
|
||||
|
||||
B_0 = blB[5];
|
||||
C5 = cj.pmadd(A0, B_0, C5);
|
||||
|
||||
B_0 = blB[6];
|
||||
C6 = cj.pmadd(A0, B_0, C6);
|
||||
|
||||
B_0 = blB[7];
|
||||
C7 = cj.pmadd(A0, B_0, C7);
|
||||
|
||||
blB += 8;
|
||||
}
|
||||
res(i, j2 + 0) += alpha * C0;
|
||||
res(i, j2 + 1) += alpha * C1;
|
||||
res(i, j2 + 2) += alpha * C2;
|
||||
res(i, j2 + 3) += alpha * C3;
|
||||
res(i, j2 + 4) += alpha * C4;
|
||||
res(i, j2 + 5) += alpha * C5;
|
||||
res(i, j2 + 6) += alpha * C6;
|
||||
res(i, j2 + 7) += alpha * C7;
|
||||
}
|
||||
}
|
||||
|
||||
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
|
||||
{
|
||||
// loop on each row of the lhs (1*LhsProgress x depth)
|
||||
for(Index i=peeled_mc_quarter; i<rows; i+=1)
|
||||
{
|
||||
const LhsScalar* blA = &blockA[i*strideA+offsetA];
|
||||
prefetch(&blA[0]);
|
||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
|
||||
|
||||
// If LhsProgress is 8 or 16, it assumes that there is a
|
||||
// half or quarter packet, respectively, of the same size as
|
||||
@ -2397,51 +2948,121 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, ColMajor, Co
|
||||
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
|
||||
Index count = 0;
|
||||
const Index peeled_k = (depth/PacketSize)*PacketSize;
|
||||
// if(nr>=8)
|
||||
// {
|
||||
// for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
// {
|
||||
// // skip what we have before
|
||||
// if(PanelMode) count += 8 * offset;
|
||||
// const Scalar* b0 = &rhs[(j2+0)*rhsStride];
|
||||
// const Scalar* b1 = &rhs[(j2+1)*rhsStride];
|
||||
// const Scalar* b2 = &rhs[(j2+2)*rhsStride];
|
||||
// const Scalar* b3 = &rhs[(j2+3)*rhsStride];
|
||||
// const Scalar* b4 = &rhs[(j2+4)*rhsStride];
|
||||
// const Scalar* b5 = &rhs[(j2+5)*rhsStride];
|
||||
// const Scalar* b6 = &rhs[(j2+6)*rhsStride];
|
||||
// const Scalar* b7 = &rhs[(j2+7)*rhsStride];
|
||||
// Index k=0;
|
||||
// if(PacketSize==8) // TODO enable vectorized transposition for PacketSize==4
|
||||
// {
|
||||
// for(; k<peeled_k; k+=PacketSize) {
|
||||
// PacketBlock<Packet> kernel;
|
||||
// for (int p = 0; p < PacketSize; ++p) {
|
||||
// kernel.packet[p] = ploadu<Packet>(&rhs[(j2+p)*rhsStride+k]);
|
||||
// }
|
||||
// ptranspose(kernel);
|
||||
// for (int p = 0; p < PacketSize; ++p) {
|
||||
// pstoreu(blockB+count, cj.pconj(kernel.packet[p]));
|
||||
// count+=PacketSize;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// for(; k<depth; k++)
|
||||
// {
|
||||
// blockB[count+0] = cj(b0[k]);
|
||||
// blockB[count+1] = cj(b1[k]);
|
||||
// blockB[count+2] = cj(b2[k]);
|
||||
// blockB[count+3] = cj(b3[k]);
|
||||
// blockB[count+4] = cj(b4[k]);
|
||||
// blockB[count+5] = cj(b5[k]);
|
||||
// blockB[count+6] = cj(b6[k]);
|
||||
// blockB[count+7] = cj(b7[k]);
|
||||
// count += 8;
|
||||
// }
|
||||
// // skip what we have after
|
||||
// if(PanelMode) count += 8 * (stride-offset-depth);
|
||||
// }
|
||||
// }
|
||||
|
||||
if(nr>=8)
|
||||
{
|
||||
for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
{
|
||||
// skip what we have before
|
||||
if(PanelMode) count += 8 * offset;
|
||||
const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
|
||||
const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
|
||||
const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
|
||||
const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
|
||||
const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
|
||||
const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
|
||||
const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
|
||||
const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
|
||||
Index k = 0;
|
||||
if (PacketSize % 2 == 0 && PacketSize <= 8) // 2 4 8
|
||||
{
|
||||
for (; k < peeled_k; k += PacketSize)
|
||||
{
|
||||
if (PacketSize == 2)
|
||||
{
|
||||
PacketBlock<Packet, PacketSize==2 ?2:PacketSize> kernel0, kernel1, kernel2, kernel3;
|
||||
kernel0.packet[0%PacketSize] = dm0.template loadPacket<Packet>(k);
|
||||
kernel0.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
|
||||
kernel1.packet[0%PacketSize] = dm2.template loadPacket<Packet>(k);
|
||||
kernel1.packet[1%PacketSize] = dm3.template loadPacket<Packet>(k);
|
||||
kernel2.packet[0%PacketSize] = dm4.template loadPacket<Packet>(k);
|
||||
kernel2.packet[1%PacketSize] = dm5.template loadPacket<Packet>(k);
|
||||
kernel3.packet[0%PacketSize] = dm6.template loadPacket<Packet>(k);
|
||||
kernel3.packet[1%PacketSize] = dm7.template loadPacket<Packet>(k);
|
||||
ptranspose(kernel0);
|
||||
ptranspose(kernel1);
|
||||
ptranspose(kernel2);
|
||||
ptranspose(kernel3);
|
||||
|
||||
pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize]));
|
||||
pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel1.packet[0 % PacketSize]));
|
||||
pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel2.packet[0 % PacketSize]));
|
||||
pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel3.packet[0 % PacketSize]));
|
||||
|
||||
pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize]));
|
||||
pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel1.packet[1 % PacketSize]));
|
||||
pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel2.packet[1 % PacketSize]));
|
||||
pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel3.packet[1 % PacketSize]));
|
||||
count+=8*PacketSize;
|
||||
}
|
||||
else if (PacketSize == 4)
|
||||
{
|
||||
PacketBlock<Packet, PacketSize == 4?4:PacketSize> kernel0, kernel1;
|
||||
|
||||
kernel0.packet[0%PacketSize] = dm0.template loadPacket<Packet>(k);
|
||||
kernel0.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
|
||||
kernel0.packet[2%PacketSize] = dm2.template loadPacket<Packet>(k);
|
||||
kernel0.packet[3%PacketSize] = dm3.template loadPacket<Packet>(k);
|
||||
kernel1.packet[0%PacketSize] = dm4.template loadPacket<Packet>(k);
|
||||
kernel1.packet[1%PacketSize] = dm5.template loadPacket<Packet>(k);
|
||||
kernel1.packet[2%PacketSize] = dm6.template loadPacket<Packet>(k);
|
||||
kernel1.packet[3%PacketSize] = dm7.template loadPacket<Packet>(k);
|
||||
ptranspose(kernel0);
|
||||
ptranspose(kernel1);
|
||||
|
||||
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel0.packet[0%PacketSize]));
|
||||
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel1.packet[0%PacketSize]));
|
||||
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel0.packet[1%PacketSize]));
|
||||
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel1.packet[1%PacketSize]));
|
||||
pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel0.packet[2%PacketSize]));
|
||||
pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel1.packet[2%PacketSize]));
|
||||
pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel0.packet[3%PacketSize]));
|
||||
pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel1.packet[3%PacketSize]));
|
||||
count+=8*PacketSize;
|
||||
}
|
||||
else if (PacketSize == 8)
|
||||
{
|
||||
PacketBlock<Packet, PacketSize==8?8:PacketSize> kernel0;
|
||||
|
||||
kernel0.packet[0%PacketSize] = dm0.template loadPacket<Packet>(k);
|
||||
kernel0.packet[1%PacketSize] = dm1.template loadPacket<Packet>(k);
|
||||
kernel0.packet[2%PacketSize] = dm2.template loadPacket<Packet>(k);
|
||||
kernel0.packet[3%PacketSize] = dm3.template loadPacket<Packet>(k);
|
||||
kernel0.packet[4%PacketSize] = dm4.template loadPacket<Packet>(k);
|
||||
kernel0.packet[5%PacketSize] = dm5.template loadPacket<Packet>(k);
|
||||
kernel0.packet[6%PacketSize] = dm6.template loadPacket<Packet>(k);
|
||||
kernel0.packet[7%PacketSize] = dm7.template loadPacket<Packet>(k);
|
||||
ptranspose(kernel0);
|
||||
|
||||
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel0.packet[0%PacketSize]));
|
||||
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel0.packet[1%PacketSize]));
|
||||
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel0.packet[2%PacketSize]));
|
||||
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel0.packet[3%PacketSize]));
|
||||
pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel0.packet[4%PacketSize]));
|
||||
pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel0.packet[5%PacketSize]));
|
||||
pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel0.packet[6%PacketSize]));
|
||||
pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel0.packet[7%PacketSize]));
|
||||
count+=8*PacketSize;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(; k<depth; k++)
|
||||
{
|
||||
blockB[count+0] = cj(dm0(k));
|
||||
blockB[count+1] = cj(dm1(k));
|
||||
blockB[count+2] = cj(dm2(k));
|
||||
blockB[count+3] = cj(dm3(k));
|
||||
blockB[count+4] = cj(dm4(k));
|
||||
blockB[count+5] = cj(dm5(k));
|
||||
blockB[count+6] = cj(dm6(k));
|
||||
blockB[count+7] = cj(dm7(k));
|
||||
count += 8;
|
||||
}
|
||||
// skip what we have after
|
||||
if(PanelMode) count += 8 * (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
|
||||
if(nr>=4)
|
||||
{
|
||||
@ -2522,39 +3143,41 @@ struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMo
|
||||
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
|
||||
Index count = 0;
|
||||
|
||||
// if(nr>=8)
|
||||
// {
|
||||
// for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
// {
|
||||
// // skip what we have before
|
||||
// if(PanelMode) count += 8 * offset;
|
||||
// for(Index k=0; k<depth; k++)
|
||||
// {
|
||||
// if (PacketSize==8) {
|
||||
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
|
||||
// pstoreu(blockB+count, cj.pconj(A));
|
||||
// } else if (PacketSize==4) {
|
||||
// Packet A = ploadu<Packet>(&rhs[k*rhsStride + j2]);
|
||||
// Packet B = ploadu<Packet>(&rhs[k*rhsStride + j2 + PacketSize]);
|
||||
// pstoreu(blockB+count, cj.pconj(A));
|
||||
// pstoreu(blockB+count+PacketSize, cj.pconj(B));
|
||||
// } else {
|
||||
// const Scalar* b0 = &rhs[k*rhsStride + j2];
|
||||
// blockB[count+0] = cj(b0[0]);
|
||||
// blockB[count+1] = cj(b0[1]);
|
||||
// blockB[count+2] = cj(b0[2]);
|
||||
// blockB[count+3] = cj(b0[3]);
|
||||
// blockB[count+4] = cj(b0[4]);
|
||||
// blockB[count+5] = cj(b0[5]);
|
||||
// blockB[count+6] = cj(b0[6]);
|
||||
// blockB[count+7] = cj(b0[7]);
|
||||
// }
|
||||
// count += 8;
|
||||
// }
|
||||
// // skip what we have after
|
||||
// if(PanelMode) count += 8 * (stride-offset-depth);
|
||||
// }
|
||||
// }
|
||||
if(nr>=8)
|
||||
{
|
||||
for(Index j2=0; j2<packet_cols8; j2+=8)
|
||||
{
|
||||
// skip what we have before
|
||||
if(PanelMode) count += 8 * offset;
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
if (PacketSize==8) {
|
||||
Packet A = rhs.template loadPacket<Packet>(k, j2);
|
||||
pstoreu(blockB+count, cj.pconj(A));
|
||||
count += PacketSize;
|
||||
} else if (PacketSize==4) {
|
||||
Packet A = rhs.template loadPacket<Packet>(k, j2);
|
||||
Packet B = rhs.template loadPacket<Packet>(k, j2 + 4);
|
||||
pstoreu(blockB+count, cj.pconj(A));
|
||||
pstoreu(blockB+count+PacketSize, cj.pconj(B));
|
||||
count += 2*PacketSize;
|
||||
} else {
|
||||
const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
|
||||
blockB[count+0] = cj(dm0(0));
|
||||
blockB[count+1] = cj(dm0(1));
|
||||
blockB[count+2] = cj(dm0(2));
|
||||
blockB[count+3] = cj(dm0(3));
|
||||
blockB[count+4] = cj(dm0(4));
|
||||
blockB[count+5] = cj(dm0(5));
|
||||
blockB[count+6] = cj(dm0(6));
|
||||
blockB[count+7] = cj(dm0(7));
|
||||
count += 8;
|
||||
}
|
||||
}
|
||||
// skip what we have after
|
||||
if(PanelMode) count += 8 * (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
if(nr>=4)
|
||||
{
|
||||
for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user