Optimize gebp kernel:

1 - increase the peeling level along the depth dimension (+5% for large matrices, i.e., >1000)
2 - improve pipelining when dealing with the last rows of the lhs
(both ideas are illustrated by the sketch below)
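Before the diff, here is a minimal self-contained sketch of the two ideas above, written as a plain scalar loop rather than Eigen's packet code (the function name and the dot-product setting are illustrative assumptions, not part of this commit): the k loop is peeled by 8, and the products go into two independent accumulators so that consecutive multiply-adds do not chain on a single register and can overlap in the pipeline.

#include <cstddef>

// Illustrative scalar analogue of the gebp micro kernel's k loop (not Eigen code).
float dot_peeled8(const float* a, const float* b, std::size_t depth)
{
  const std::size_t peeled_kc = depth & ~std::size_t(7); // round depth down to a multiple of 8
  float c_even = 0.f, c_odd = 0.f; // two independent accumulators -> shorter dependency chains
  for(std::size_t k=0; k<peeled_kc; k+=8)
  {
    c_even += a[k+0]*b[k+0];  c_odd += a[k+1]*b[k+1];
    c_even += a[k+2]*b[k+2];  c_odd += a[k+3]*b[k+3];
    c_even += a[k+4]*b[k+4];  c_odd += a[k+5]*b[k+5];
    c_even += a[k+6]*b[k+6];  c_odd += a[k+7]*b[k+7];
  }
  for(std::size_t k=peeled_kc; k<depth; k++) // process the remaining, non-peeled iterations
    c_even += a[k]*b[k];
  return c_even + c_odd; // merge the partial sums once, at the end
}

In the kernel below, the last-rows path applies the same pattern with SIMD packets: straits.madd plays the role of the multiply-adds, and the even/odd accumulators are merged with padd before the final acc.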
Gael Guennebaud 2014-03-30 21:57:05 +02:00
parent ad59ade116
commit e497a27ddc


@@ -95,6 +95,9 @@ void computeProductBlockingSizes(SizeType& k, SizeType& m, SizeType& n)
   k = std::min<SizeType>(k, l1/kdiv);
   SizeType _m = k>0 ? l2/(4 * sizeof(LhsScalar) * k) : 0;
   if(_m<m) m = _m & mr_mask;
+  m = 1024;
+  k = 256;
 }

 template<typename LhsScalar, typename RhsScalar, typename SizeType>
@@ -328,6 +331,22 @@ protected:
   conj_helper<ResPacket,ResPacket,ConjLhs,false> cj;
 };

+template<typename Packet>
+struct DoublePacket
+{
+  Packet first;
+  Packet second;
+};
+
+template<typename Packet>
+DoublePacket<Packet> padd(const DoublePacket<Packet> &a, const DoublePacket<Packet> &b)
+{
+  DoublePacket<Packet> res;
+  res.first = padd(a.first, b.first);
+  res.second = padd(a.second,b.second);
+  return res;
+}
+
 template<typename RealScalar, bool _ConjLhs, bool _ConjRhs>
 class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, _ConjLhs, _ConjRhs >
 {
@@ -357,20 +376,16 @@ public:
   typedef typename packet_traits<RealScalar>::type RealPacket;
   typedef typename packet_traits<Scalar>::type ScalarPacket;
-  struct DoublePacket
-  {
-    RealPacket first;
-    RealPacket second;
-  };
+  typedef DoublePacket<RealPacket> DoublePacketType;

   typedef typename conditional<Vectorizable,RealPacket, Scalar>::type LhsPacket;
-  typedef typename conditional<Vectorizable,DoublePacket,Scalar>::type RhsPacket;
+  typedef typename conditional<Vectorizable,DoublePacketType,Scalar>::type RhsPacket;
   typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type ResPacket;
-  typedef typename conditional<Vectorizable,DoublePacket,Scalar>::type AccPacket;
+  typedef typename conditional<Vectorizable,DoublePacketType,Scalar>::type AccPacket;

   EIGEN_STRONG_INLINE void initAcc(Scalar& p) { p = Scalar(0); }

-  EIGEN_STRONG_INLINE void initAcc(DoublePacket& p)
+  EIGEN_STRONG_INLINE void initAcc(DoublePacketType& p)
   {
     p.first = pset1<RealPacket>(RealScalar(0));
     p.second = pset1<RealPacket>(RealScalar(0));
@@ -383,7 +398,7 @@ public:
   }

   // Vectorized path
-  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacket& dest) const
+  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacketType& dest) const
   {
     dest.first = pset1<RealPacket>(real(*b));
     dest.second = pset1<RealPacket>(imag(*b));
@@ -393,7 +408,7 @@ public:
   void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);

   // Vectorized path
-  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, DoublePacket& b0, DoublePacket& b1)
+  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, DoublePacketType& b0, DoublePacketType& b1)
   {
     // FIXME not sure that's the best way to implement it!
     loadRhs(b+0, b0);
@@ -419,7 +434,7 @@ public:
     dest = ploadu<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
   }

-  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacket& c, RhsPacket& /*tmp*/) const
+  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacketType& c, RhsPacket& /*tmp*/) const
   {
     c.first = padd(pmul(a,b.first), c.first);
     c.second = padd(pmul(a,b.second),c.second);
@@ -432,7 +447,7 @@ public:
   EIGEN_STRONG_INLINE void acc(const Scalar& c, const Scalar& alpha, Scalar& r) const { r += alpha * c; }

-  EIGEN_STRONG_INLINE void acc(const DoublePacket& c, const ResPacket& alpha, ResPacket& r) const
+  EIGEN_STRONG_INLINE void acc(const DoublePacketType& c, const ResPacket& alpha, ResPacket& r) const
   {
     // assemble c
     ResPacket tmp;
@@ -572,6 +587,14 @@ struct gebp_kernel
   typedef typename Traits::ResPacket ResPacket;
   typedef typename Traits::AccPacket AccPacket;

+  typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
+  typedef typename SwappedTraits::ResScalar SResScalar;
+  typedef typename SwappedTraits::LhsPacket SLhsPacket;
+  typedef typename SwappedTraits::RhsPacket SRhsPacket;
+  typedef typename SwappedTraits::ResPacket SResPacket;
+  typedef typename SwappedTraits::AccPacket SAccPacket;
+
 enum {
     Vectorizable = Traits::Vectorizable,
     LhsProgress = Traits::LhsProgress,
@@ -591,6 +614,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
   Index strideA, Index strideB, Index offsetA, Index offsetB)
 {
   Traits traits;
+  SwappedTraits straits;

   if(strideA==-1) strideA = depth;
   if(strideB==-1) strideB = depth;
@@ -599,7 +623,9 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
   Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
   // Here we assume that mr==LhsProgress
   const Index peeled_mc = (rows/mr)*mr;
-  const Index peeled_kc = (depth/4)*4;
+  enum { pk = 8 }; // NOTE Such a large peeling factor is important for large matrices (~ +5% when >1000 on Haswell)
+  const Index peeled_kc = depth & ~(pk-1);
+  const Index depth2 = depth & ~1;

   // loops on each micro vertical panel of rhs (depth x nr)
   // First pass using depth x 8 panels
@@ -634,14 +660,14 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
         // uncomment for register prefetching
         // LhsPacket A1;
         // traits.loadLhs(blA, A0);
-        for(Index k=0; k<peeled_kc; k+=4)
+        for(Index k=0; k<peeled_kc; k+=pk)
         {
           EIGEN_ASM_COMMENT("begin gebp micro kernel 1p x 8");
           RhsPacket B_0, B1, B2, B3;
-          // The following version is faster on some architectures
+          // NOTE The following version is faster on some architectures
           // but sometimes leads to segfaults because it might read one packet outside the bounds
           // To test it, you also need to uncomment the initialization of A0 above and the copy of A1 to A0 below.
 #if 0
 #define EIGEN_GEBGP_ONESTEP8(K,L,M) \
           traits.loadLhs(&blA[(K+1)*LhsProgress], L); \
@@ -674,9 +700,13 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
           EIGEN_GEBGP_ONESTEP8(1,A0,A1);
           EIGEN_GEBGP_ONESTEP8(2,A1,A0);
           EIGEN_GEBGP_ONESTEP8(3,A0,A1);
+          EIGEN_GEBGP_ONESTEP8(4,A1,A0);
+          EIGEN_GEBGP_ONESTEP8(5,A0,A1);
+          EIGEN_GEBGP_ONESTEP8(6,A1,A0);
+          EIGEN_GEBGP_ONESTEP8(7,A0,A1);

-          blB += 4*8*RhsProgress;
-          blA += 4*mr;
+          blB += pk*8*RhsProgress;
+          blA += pk*mr;
         }
         // process remaining peeled loop
         for(Index k=peeled_kc; k<depth; k++)
@@ -720,97 +750,170 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
         pstoreu(r0+5*resStride, R5);
         pstoreu(r0+6*resStride, R6);
         pstoreu(r0+7*resStride, R0);
       }
-      for(Index i=peeled_mc; i<rows; i++)
-      {
-        const LhsScalar* blA = &blockA[i*strideA+offsetA];
-        prefetch(&blA[0]);
-
-        typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
-        typedef typename SwappedTraits::ResScalar SResScalar;
-        typedef typename SwappedTraits::LhsPacket SLhsPacket;
-        typedef typename SwappedTraits::RhsPacket SRhsPacket;
-        typedef typename SwappedTraits::ResPacket SResPacket;
-        typedef typename SwappedTraits::AccPacket SAccPacket;
-        SwappedTraits straits;
-
-        // FIXME directly use blockB ???
-        const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
-
-        if(nr == Traits::RhsPacketSize)
-        {
-          EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
-          typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
-          typedef typename SwappedTraits::ResScalar SResScalar;
-          typedef typename SwappedTraits::LhsPacket SLhsPacket;
-          typedef typename SwappedTraits::RhsPacket SRhsPacket;
-          typedef typename SwappedTraits::ResPacket SResPacket;
-          typedef typename SwappedTraits::AccPacket SAccPacket;
-          SwappedTraits straits;
-
-          SAccPacket C0;
-          straits.initAcc(C0);
-          for(Index k=0; k<depth; k++)
-          {
-            SLhsPacket A0;
-            straits.loadLhsUnaligned(blB, A0);
-            SRhsPacket B_0;
-            straits.loadRhs(&blA[k], B_0);
-            SRhsPacket T0;
-            straits.madd(A0,B_0,C0,T0);
-            blB += nr;
-          }
-          SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
-          SResPacket alphav = pset1<SResPacket>(alpha);
-          straits.acc(C0, alphav, R);
-          pscatter(&res[j2*resStride + i], R, resStride);
-          EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
-        }
-        else
-        {
-          // gets a 1 x 8 res block as registers
-          ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
-          for(Index k=0; k<depth; k++)
-          {
-            LhsScalar A0;
-            RhsScalar B_0, B_1;
-            A0 = blA[k];
-            B_0 = blB[0];
-            B_1 = blB[1];
-            MADD(cj,A0,B_0,C0, B_0);
-            MADD(cj,A0,B_1,C1, B_1);
-            B_0 = blB[2];
-            B_1 = blB[3];
-            MADD(cj,A0,B_0,C2, B_0);
-            MADD(cj,A0,B_1,C3, B_1);
-            B_0 = blB[4];
-            B_1 = blB[5];
-            MADD(cj,A0,B_0,C4, B_0);
-            MADD(cj,A0,B_1,C5, B_1);
-            B_0 = blB[6];
-            B_1 = blB[7];
-            MADD(cj,A0,B_0,C6, B_0);
-            MADD(cj,A0,B_1,C7, B_1);
-            blB += 8;
-          }
-          res[(j2+0)*resStride + i] += alpha*C0;
-          res[(j2+1)*resStride + i] += alpha*C1;
-          res[(j2+2)*resStride + i] += alpha*C2;
-          res[(j2+3)*resStride + i] += alpha*C3;
-          res[(j2+4)*resStride + i] += alpha*C4;
-          res[(j2+5)*resStride + i] += alpha*C5;
-          res[(j2+6)*resStride + i] += alpha*C6;
-          res[(j2+7)*resStride + i] += alpha*C7;
-        }
-      }
+      // Deal with remaining rows of the lhs
+      // TODO we should vectorize if <= 8, and not strictly ==
+      if(SwappedTraits::LhsProgress == 8)
+      {
+        // Apply the same logic but with reversed operands.
+        // To improve pipelining, we process 2 rows at once and accumulate even and odd products along the k dimension
+        // into two different packets.
+        Index rows2 = (rows & ~1);
+        for(Index i=peeled_mc; i<rows2; i+=2)
+        {
+          const LhsScalar* blA = &blockA[i*strideA+offsetA];
+          const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
+
+          EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows 2x8");
+
+          SAccPacket C0,C1, C2,C3;
+          straits.initAcc(C0); // even
+          straits.initAcc(C1); // odd
+          straits.initAcc(C2); // even
+          straits.initAcc(C3); // odd
+          for(Index k=0; k<depth2; k+=2)
+          {
+            SLhsPacket A0, A1;
+            straits.loadLhsUnaligned(blB+0, A0);
+            straits.loadLhsUnaligned(blB+8, A1);
+            SRhsPacket B_0, B_1, B_2, B_3;
+            straits.loadRhs(blA+k+0, B_0);
+            straits.loadRhs(blA+k+1, B_1);
+            straits.loadRhs(blA+strideA+k+0, B_2);
+            straits.loadRhs(blA+strideA+k+1, B_3);
+            straits.madd(A0,B_0,C0,B_0);
+            straits.madd(A1,B_1,C1,B_1);
+            straits.madd(A0,B_2,C2,B_2);
+            straits.madd(A1,B_3,C3,B_3);
+            blB += 2*nr;
+          }
+          if(depth2<depth)
+          {
+            Index k = depth-1;
+            SLhsPacket A0;
+            straits.loadLhsUnaligned(blB+0, A0);
+            SRhsPacket B_0, B_2;
+            straits.loadRhs(blA+k+0, B_0);
+            straits.loadRhs(blA+strideA+k+0, B_2);
+            straits.madd(A0,B_0,C0,B_0);
+            straits.madd(A0,B_2,C2,B_2);
+          }
+          SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
+          SResPacket alphav = pset1<SResPacket>(alpha);
+          straits.acc(padd(C0,C1), alphav, R);
+          pscatter(&res[j2*resStride + i], R, resStride);
+
+          R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i + 1], resStride);
+          straits.acc(padd(C2,C3), alphav, R);
+          pscatter(&res[j2*resStride + i + 1], R, resStride);
+
+          EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows 8");
+        }
+        if(rows2!=rows)
+        {
+          Index i = rows-1;
+          const LhsScalar* blA = &blockA[i*strideA+offsetA];
+          const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
+
+          EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows 8");
+
+          SAccPacket C0,C1;
+          straits.initAcc(C0); // even
+          straits.initAcc(C1); // odd
+          for(Index k=0; k<depth2; k+=2)
+          {
+            SLhsPacket A0, A1;
+            straits.loadLhsUnaligned(blB+0, A0);
+            straits.loadLhsUnaligned(blB+8, A1);
+            SRhsPacket B_0, B_1;
+            straits.loadRhs(blA+k+0, B_0);
+            straits.loadRhs(blA+k+1, B_1);
+            straits.madd(A0,B_0,C0,B_0);
+            straits.madd(A1,B_1,C1,B_1);
+            blB += 2*8;
+          }
+          if(depth!=depth2)
+          {
+            Index k = depth-1;
+            SLhsPacket A0;
+            straits.loadLhsUnaligned(blB+0, A0);
+            SRhsPacket B_0;
+            straits.loadRhs(blA+k+0, B_0);
+            straits.madd(A0,B_0,C0,B_0);
+          }
+          SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
+          SResPacket alphav = pset1<SResPacket>(alpha);
+          straits.acc(padd(C0,C1), alphav, R);
+          pscatter(&res[j2*resStride + i], R, resStride);
+        }
+      }
+      else
+      {
+        // Pure scalar path
+        for(Index i=peeled_mc; i<rows; i++)
+        {
+          const LhsScalar* blA = &blockA[i*strideA+offsetA];
+          const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
+
+          // gets a 1 x 8 res block as registers
+          ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
+          for(Index k=0; k<depth; k++)
+          {
+            LhsScalar A0;
+            RhsScalar B_0, B_1;
+            A0 = blA[k];
+
+            B_0 = blB[0];
+            B_1 = blB[1];
+            MADD(cj,A0,B_0,C0, B_0);
+            MADD(cj,A0,B_1,C1, B_1);
+
+            B_0 = blB[2];
+            B_1 = blB[3];
+            MADD(cj,A0,B_0,C2, B_0);
+            MADD(cj,A0,B_1,C3, B_1);
+
+            B_0 = blB[4];
+            B_1 = blB[5];
+            MADD(cj,A0,B_0,C4, B_0);
+            MADD(cj,A0,B_1,C5, B_1);
+
+            B_0 = blB[6];
+            B_1 = blB[7];
+            MADD(cj,A0,B_0,C6, B_0);
+            MADD(cj,A0,B_1,C7, B_1);
+
+            blB += 8;
+          }
+          res[(j2+0)*resStride + i] += alpha*C0;
+          res[(j2+1)*resStride + i] += alpha*C1;
+          res[(j2+2)*resStride + i] += alpha*C2;
+          res[(j2+3)*resStride + i] += alpha*C3;
+          res[(j2+4)*resStride + i] += alpha*C4;
+          res[(j2+5)*resStride + i] += alpha*C5;
+          res[(j2+6)*resStride + i] += alpha*C6;
+          res[(j2+7)*resStride + i] += alpha*C7;
+        }
+      }
     }
   }
   // Second pass using depth x 4 panels
   // If nr==8, then we have at most one such panel
+  // TODO: with 16 registers, we could optimize this part to leverage more pipelining,
+  // for instance, by using a 2 packet * 4 kernel. Useful when the rhs is thin
   if(nr>=4)
   {
     for(Index j2=packet_cols8; j2<packet_cols4; j2+=4)
@@ -821,7 +924,6 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
       for(Index i=0; i<peeled_mc; i+=mr)
       {
         const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
-        // prefetch(&blA[0]);

         // gets res block as register
         AccPacket C0, C1, C2, C3;
@@ -835,15 +937,11 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
         // performs "inner" products
         const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
         LhsPacket A0;
-        // uncomment for register prefetching
-        // LhsPacket A1;
-        // traits.loadLhs(blA, A0);
-        for(Index k=0; k<peeled_kc; k+=4)
+        for(Index k=0; k<peeled_kc; k+=pk)
         {
           EIGEN_ASM_COMMENT("begin gebp micro kernel 1p x 4");
           RhsPacket B_0, B1;
 #define EIGEN_GEBGP_ONESTEP4(K) \
           traits.loadLhs(&blA[K*LhsProgress], A0); \
           traits.broadcastRhs(&blB[0+4*K*RhsProgress], B_0, B1); \
@@ -857,11 +955,15 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
           EIGEN_GEBGP_ONESTEP4(1);
           EIGEN_GEBGP_ONESTEP4(2);
           EIGEN_GEBGP_ONESTEP4(3);
+          EIGEN_GEBGP_ONESTEP4(4);
+          EIGEN_GEBGP_ONESTEP4(5);
+          EIGEN_GEBGP_ONESTEP4(6);
+          EIGEN_GEBGP_ONESTEP4(7);

-          blB += 4*4*RhsProgress;
-          blA += 4*mr;
+          blB += pk*4*RhsProgress;
+          blA += pk*mr;
         }
-        // process remaining peeled loop
+        // process the remainder of the peeled loop
         for(Index k=peeled_kc; k<depth; k++)
         {
           RhsPacket B_0, B1;
@@ -894,98 +996,86 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
       for(Index i=peeled_mc; i<rows; i++)
       {
         const LhsScalar* blA = &blockA[i*strideA+offsetA];
-        prefetch(&blA[0]);
-
-        // FIXME directly use blockB ???
         const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];

-        if(nr == Traits::RhsPacketSize)
-        {
-          EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
-          typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
-          typedef typename SwappedTraits::ResScalar SResScalar;
-          typedef typename SwappedTraits::LhsPacket SLhsPacket;
-          typedef typename SwappedTraits::RhsPacket SRhsPacket;
-          typedef typename SwappedTraits::ResPacket SResPacket;
-          typedef typename SwappedTraits::AccPacket SAccPacket;
-          SwappedTraits straits;
-
+        // TODO vectorize in more cases
+        if(SwappedTraits::LhsProgress==4)
+        {
+          EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows 1x4");
           SAccPacket C0;
           straits.initAcc(C0);
           for(Index k=0; k<depth; k++)
           {
             SLhsPacket A0;
             straits.loadLhsUnaligned(blB, A0);
             SRhsPacket B_0;
             straits.loadRhs(&blA[k], B_0);
             SRhsPacket T0;
             straits.madd(A0,B_0,C0,T0);
-            blB += nr;
+            blB += 4;
           }
           SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
           SResPacket alphav = pset1<SResPacket>(alpha);
           straits.acc(C0, alphav, R);
           pscatter(&res[j2*resStride + i], R, resStride);
-          EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
-        } else {
-          // gets a 1 x 4 res block as registers
-          ResScalar C0(0), C1(0), C2(0), C3(0);
+          EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows 1x4");
+        }
+        else
+        {
+          // Pure scalar path
+          // gets a 1 x 4 res block as registers
+          ResScalar C0(0), C1(0), C2(0), C3(0);
           for(Index k=0; k<depth; k++)
           {
             LhsScalar A0;
             RhsScalar B_0, B_1;
             A0 = blA[k];
             B_0 = blB[0];
             B_1 = blB[1];
             MADD(cj,A0,B_0,C0, B_0);
             MADD(cj,A0,B_1,C1, B_1);
             B_0 = blB[2];
             B_1 = blB[3];
             MADD(cj,A0,B_0,C2, B_0);
             MADD(cj,A0,B_1,C3, B_1);
             blB += 4;
           }
           res[(j2+0)*resStride + i] += alpha*C0;
           res[(j2+1)*resStride + i] += alpha*C1;
           res[(j2+2)*resStride + i] += alpha*C2;
           res[(j2+3)*resStride + i] += alpha*C3;
         }
       }
     }
   }
   // process remaining rhs/res columns one at a time
+  // => do the same but with nr==1
   for(Index j2=packet_cols4; j2<cols; j2++)
   {
-    // vectorized path
     for(Index i=0; i<peeled_mc; i+=mr)
     {
-      const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
-      prefetch(&blA[0]);
-
-      // TODO move the res loads to the stores
-
       // get res block as registers
       AccPacket C0;
       traits.initAcc(C0);

+      const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
       const RhsScalar* blB = &blockB[j2*strideB+offsetB];
       for(Index k=0; k<depth; k++)
       {
         LhsPacket A0;
         RhsPacket B_0;
-        RhsPacket T0;
-        traits.loadLhs(&blA[0*LhsProgress], A0);
-        traits.loadRhs(&blB[0*RhsProgress], B_0);
-        traits.madd(A0,B_0,C0,T0);
+        traits.loadLhs(blA, A0);
+        traits.loadRhs(blB, B_0);
+        traits.madd(A0,B_0,C0,B_0);
         blB += RhsProgress;
         blA += LhsProgress;
@@ -997,14 +1087,12 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
       traits.acc(C0, alphav, R0);
       pstoreu(r0, R0);
     }
-    // pure scalar path
     for(Index i=peeled_mc; i<rows; i++)
     {
       const LhsScalar* blA = &blockA[i*strideA+offsetA];
-      prefetch(&blA[0]);
       // gets a 1 x 1 res block as registers
       ResScalar C0(0);
-      // FIXME directly use blockB ??
       const RhsScalar* blB = &blockB[j2*strideB+offsetB];
       for(Index k=0; k<depth; k++)
       {