Fix and optimize mixed products

commit 11fbdcbc38
parent 0fa8290366
Author: Gael Guennebaud
Date:   2014-04-17 16:04:30 +02:00


@@ -180,14 +180,15 @@ public:
     NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
-    // register block size along the N direction (must be either 2 or 4)
-    nr = 4,//NumberOfRegisters/4,
+    // register block size along the N direction must be 1 or 4
+    nr = 4,
     // register block size along the M direction (currently, this one cannot be modified)
 #ifdef __FMA__
+    // we assume 16 registers
     mr = 3*LhsPacketSize,
 #else
-    mr = 2*LhsPacketSize,
+    mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*LhsPacketSize,
 #endif
     LhsProgress = LhsPacketSize,
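The non-FMA `mr` is now derived from the register budget instead of being hard-coded to `2*LhsPacketSize`. Below is a minimal standalone model of that arithmetic, under assumed AVX-like values (16 SIMD registers, 8 floats per packet); the real constants live in Eigen's gebp_traits:

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        // Assumed values: EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS = 16 on x86-64,
        // LhsPacketSize = 8 (floats per AVX register), nr = 4.
        const int NumberOfRegisters = 16;
        const int LhsPacketSize     = 8;
        const int nr                = 4;

        // Non-FMA path: half the (capped) register file goes to accumulators,
        // split across the nr columns of the micro kernel.
        const int mr = (std::min(16, NumberOfRegisters) / 2 / nr) * LhsPacketSize;
        std::printf("mr = %d\n", mr);  // 2*LhsPacketSize = 16 on a 16-register target
    }

On a 16-register target this reproduces the old `2*LhsPacketSize`, but on an 8-register target (32-bit x86) it degrades gracefully to a single packet per row block.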
@@ -209,15 +210,15 @@ public:
     p = pset1<ResPacket>(ResScalar(0));
   }
 
-  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
-  {
-    pbroadcast4(b, b0, b1, b2, b3);
-  }
-
-  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
-  {
-    pbroadcast2(b, b0, b1);
-  }
+//   EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
+//   {
+//     pbroadcast4(b, b0, b1, b2, b3);
+//   }
+//
+//   EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
+//   {
+//     pbroadcast2(b, b0, b1);
+//   }
 
   template<typename RhsPacketType>
   EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
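For reference, the `broadcastRhs` helpers being commented out here wrapped `pbroadcast4`/`pbroadcast2`, which load 4 (resp. 2) consecutive rhs scalars and splat each one across a full packet. A scalar model of that behaviour, with illustrative names rather than Eigen's actual packet API:

    #include <array>
    #include <cstdio>

    template<int N>
    using Packet = std::array<float, N>;

    // Model of pbroadcast4: b0..b3 each receive one of b[0..3] replicated N times.
    template<int N>
    void pbroadcast4_model(const float* b, Packet<N>& b0, Packet<N>& b1,
                           Packet<N>& b2, Packet<N>& b3)
    {
        b0.fill(b[0]); b1.fill(b[1]); b2.fill(b[2]); b3.fill(b[3]);
    }

    int main()
    {
        const float rhs[4] = {1.f, 2.f, 3.f, 4.f};
        Packet<4> b0, b1, b2, b3;
        pbroadcast4_model(rhs, b0, b1, b2, b3);
        std::printf("b2 = {%g, %g, %g, %g}\n", b2[0], b2[1], b2[2], b2[3]); // all 3s
    }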
@@ -290,8 +291,13 @@ public:
     ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
     NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
-    nr = NumberOfRegisters/2,
-    mr = LhsPacketSize,
+    nr = 4,
+#ifdef __FMA__
+    // we assume 16 registers
+    mr = 3*LhsPacketSize,
+#else
+    mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*LhsPacketSize,
+#endif
     LhsProgress = LhsPacketSize,
     RhsProgress = 1
@@ -332,15 +338,15 @@ public:
     dest = ploadu<LhsPacket>(a);
   }
 
-  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
-  {
-    pbroadcast4(b, b0, b1, b2, b3);
-  }
-
-  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
-  {
-    pbroadcast2(b, b0, b1);
-  }
+//   EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
+//   {
+//     pbroadcast4(b, b0, b1, b2, b3);
+//   }
+//
+//   EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
+//   {
+//     pbroadcast2(b, b0, b1);
+//   }
 
   EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
   {
@@ -566,7 +572,7 @@ public:
     NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
     // FIXME: should depend on NumberOfRegisters
     nr = 4,
-    mr = ResPacketSize,
+    mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*ResPacketSize,
     LhsProgress = ResPacketSize,
     RhsProgress = 1
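The same register-budget formula is applied here, where `mr` used to be a single packet. If I read `EIGEN_PLAIN_ENUM_MIN` correctly it is a plain ternary min over ints (so it stays usable inside enum initializers); a sketch of how the capped formula behaves across register counts, with illustrative values:

    #include <cstdio>

    // Assumed model of EIGEN_PLAIN_ENUM_MIN (a ternary int min):
    #define PLAIN_ENUM_MIN(A,B) (((int)(A) <= (int)(B)) ? (int)(A) : (int)(B))

    int main()
    {
        const int ResPacketSize = 4;  // e.g. 4 floats per SSE register
        const int nr = 4;
        const int register_counts[] = {8, 16, 32};  // x86-32, x86-64, wider ISAs
        for (int regs : register_counts)
            std::printf("NumberOfRegisters = %2d  ->  mr = %d\n",
                        regs, (PLAIN_ENUM_MIN(16, regs) / 2 / nr) * ResPacketSize);
        // The min(16, .) cap keeps mr at two packets even on 32-register targets.
    }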
@@ -593,20 +599,26 @@ public:
   }
 
   // linking error if instantiated without being optimized out:
-  void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
-
-  EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
-  {
-    // FIXME not sure that's the best way to implement it!
-    b0 = pload1<RhsPacket>(b+0);
-    b1 = pload1<RhsPacket>(b+1);
-  }
+//   void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
+//
+//   EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
+//   {
+//     // FIXME not sure that's the best way to implement it!
+//     b0 = pload1<RhsPacket>(b+0);
+//     b1 = pload1<RhsPacket>(b+1);
+//   }
 
   EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
   {
     dest = ploaddup<LhsPacket>(a);
   }
+
+  EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
+  {
+    eigen_internal_assert(unpacket_traits<RhsPacket>::size<=4);
+    loadRhs(b,dest);
+  }
 
   EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
   {
     dest = ploaddup<LhsPacket>(a);
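The new `loadRhsQuad` simply defers to `loadRhs`, guarded by an assert that the packet holds at most four scalars. The surrounding `loadLhs`/`loadLhsUnaligned` use `ploaddup`, which loads half a packet's worth of elements and duplicates each, so that real lhs values pair up against the interleaved real/imag parts of a complex rhs (this appears to be the real-times-complex specialization). A scalar model of that duplication, illustrative names only:

    #include <array>
    #include <cstdio>

    // Model of ploaddup for a 4-wide packet: reads a[0], a[1] and returns
    // {a0, a0, a1, a1}.
    std::array<float, 4> ploaddup_model(const float* a)
    {
        return {a[0], a[0], a[1], a[1]};
    }

    int main()
    {
        const float lhs[2] = {5.f, 7.f};
        const auto p = ploaddup_model(lhs);
        std::printf("{%g, %g, %g, %g}\n", p[0], p[1], p[2], p[3]); // {5, 5, 7, 7}
    }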
@@ -619,7 +631,13 @@ public:
   EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
   {
+#ifdef EIGEN_VECTORIZE_FMA
+    EIGEN_UNUSED_VARIABLE(tmp);
+    c.v = pmadd(a,b.v,c.v);
+#else
     tmp = b; tmp.v = pmul(a,tmp.v); c = padd(c,tmp);
+#endif
   }
 
   EIGEN_STRONG_INLINE void madd_impl(const LhsScalar& a, const RhsScalar& b, ResScalar& c, RhsScalar& /*tmp*/, const false_type&) const
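The vectorized `madd_impl` now has an FMA fast path: with `EIGEN_VECTORIZE_FMA` the staging register `tmp` is unused and the multiply-add is fused via `pmadd`; otherwise the product is still staged through `tmp` before the accumulate. A scalar model of the two paths (names are illustrative):

    #include <cstdio>

    // FMA path: one fused operation, no temporary (c.v = pmadd(a, b.v, c.v)).
    float madd_fma(float a, float b, float c) { return a * b + c; }

    // Fallback path: stage the product in tmp, then accumulate
    // (tmp = b; tmp.v = pmul(a, tmp.v); c = padd(c, tmp)).
    float madd_mul_add(float a, float b, float c)
    {
        float tmp = b;
        tmp = a * tmp;   // pmul
        return c + tmp;  // padd
    }

    int main()
    {
        std::printf("%g %g\n", madd_fma(2.f, 3.f, 4.f), madd_mul_add(2.f, 3.f, 4.f)); // 10 10
    }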
@@ -956,7 +974,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
         for(Index k=0; k<peeled_kc; k+=pk)
         {
           IACA_START
-          EIGEN_ASM_COMMENT("begin gegp micro kernel 2p x 4");
+          EIGEN_ASM_COMMENT("begin gegp micro kernel 2pX4");
           RhsPacket B_0, B1;
 #define EIGEN_GEBGP_ONESTEP(K) \
@@ -1134,7 +1152,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
         for(Index k=0; k<peeled_kc; k+=pk)
         {
           IACA_START
-          EIGEN_ASM_COMMENT("begin gegp micro kernel 2p x 4");
+          EIGEN_ASM_COMMENT("begin gegp micro kernel 1pX4");
           RhsPacket B_0, B1;
 #define EIGEN_GEBGP_ONESTEP(K) \
@@ -1160,7 +1178,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
           EIGEN_GEBGP_ONESTEP(7);
           blB += pk*4*RhsProgress;
-          blA += pk*(1*Traits::LhsProgress);
+          blA += pk*1*LhsProgress;
           IACA_END
         }
         // process remaining peeled loop
@@ -1169,7 +1187,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
           RhsPacket B_0, B1;
           EIGEN_GEBGP_ONESTEP(0);
           blB += 4*RhsProgress;
-          blA += 1*Traits::LhsProgress;
+          blA += 1*LhsProgress;
         }
 #undef EIGEN_GEBGP_ONESTEP
@@ -1439,6 +1457,8 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
   const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
   const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
   const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
+  const Index peeled_mc0 = Pack2>=1*PacketSize ? peeled_mc1
+                         : Pack2>1             ? (rows/Pack2)*Pack2 : 0;
 
   // Pack 3 packets
   if(Pack1>=3*PacketSize)
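The new `peeled_mc0` bound extends the peeling hierarchy down to `Pack2`-wide panels. A worked example of where the four boundaries land, using illustrative sizes (PacketSize=4, Pack1=12, Pack2=2, rows=23):

    #include <cstdio>

    int main()
    {
        const int PacketSize = 4, Pack1 = 12, Pack2 = 2, rows = 23;

        const int peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
        const int peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
        const int peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
        const int peeled_mc0 = Pack2>=1*PacketSize ? peeled_mc1
                             : Pack2>1 ? (rows/Pack2)*Pack2 : 0;

        // Rows [0,12) go three packets at a time, [12,20) two at a time,
        // [20,22) in Pack2-wide panels, and row 22 scalar by scalar.
        std::printf("mc3=%d mc2=%d mc1=%d mc0=%d\n",
                    peeled_mc3, peeled_mc2, peeled_mc1, peeled_mc0); // 12 20 20 22
    }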
@@ -1496,16 +1516,20 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
     }
   }
   // Pack scalars
-//   if(rows-peeled_mc>=Pack2)
-//   {
-//     if(PanelMode) count += Pack2*offset;
-//     for(Index k=0; k<depth; k++)
-//       for(Index w=0; w<Pack2; w++)
-//         blockA[count++] = cj(lhs(peeled_mc+w, k));
-//     if(PanelMode) count += Pack2 * (stride-offset-depth);
-//     peeled_mc += Pack2;
-//   }
-  for(Index i=peeled_mc1; i<rows; i++)
+  if(Pack2<PacketSize && Pack2>1)
+  {
+    for(Index i=peeled_mc1; i<peeled_mc0; i+=Pack2)
+    {
+      if(PanelMode) count += Pack2 * offset;
+
+      for(Index k=0; k<depth; k++)
+        for(Index w=0; w<Pack2; w++)
+          blockA[count++] = cj(lhs(i+w, k));
+
+      if(PanelMode) count += Pack2 * (stride-offset-depth);
+    }
+  }
+  for(Index i=peeled_mc0; i<rows; i++)
   {
     if(PanelMode) count += offset;
     for(Index k=0; k<depth; k++)
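The re-enabled scalar packing stores each `Pack2`-wide panel in k-major order: for every column k, the Pack2 rows are written contiguously. A tiny model of the resulting blockA layout, with a hypothetical stand-in matrix lhs(i,k) = 10*i + k:

    #include <cstdio>

    int main()
    {
        const int Pack2 = 2, depth = 3, i0 = 20;  // panel starts at row 20
        float blockA[Pack2 * depth];
        int count = 0;
        auto lhs = [](int i, int k) { return float(10 * i + k); };  // stand-in matrix
        for (int k = 0; k < depth; ++k)
            for (int w = 0; w < Pack2; ++w)
                blockA[count++] = lhs(i0 + w, k);
        for (int n = 0; n < count; ++n) std::printf("%g ", blockA[n]);
        std::printf("\n");  // 200 210 201 211 202 212 : the two rows interleave per k
    }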
@@ -1539,35 +1563,36 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj
 //   const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
 //   const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
-  int pack_packets = Pack1/PacketSize;
+  int pack = Pack1;
   Index i = 0;
-  while(pack_packets>0)
+  while(pack>0)
   {
     Index remaining_rows = rows-i;
-    Index peeled_mc = i+(remaining_rows/(pack_packets*PacketSize))*(pack_packets*PacketSize);
-    // std::cout << "pack_packets = " << pack_packets << " from " << i << " to " << peeled_mc << "\n";
-    for(; i<peeled_mc; i+=pack_packets*PacketSize)
+    Index peeled_mc = i+(remaining_rows/pack)*pack;
+    for(; i<peeled_mc; i+=pack)
     {
-      if(PanelMode) count += (pack_packets*PacketSize) * offset;
+      if(PanelMode) count += pack * offset;
       const Index peeled_k = (depth/PacketSize)*PacketSize;
       Index k=0;
-      for(; k<peeled_k; k+=PacketSize)
-      {
-        for (Index m = 0; m < (pack_packets*PacketSize); m += PacketSize)
-        {
-          Kernel<Packet> kernel;
-          for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k));
-          ptranspose(kernel);
-          for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack_packets*PacketSize)*p, cj.pconj(kernel.packet[p]));
-        }
-        count += PacketSize*(pack_packets*PacketSize);
-      }
+      if(pack>=PacketSize)
+      {
+        for(; k<peeled_k; k+=PacketSize)
+        {
+          for (Index m = 0; m < pack; m += PacketSize)
+          {
+            Kernel<Packet> kernel;
+            for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k));
+            ptranspose(kernel);
+            for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel.packet[p]));
+          }
+          count += PacketSize*pack;
+        }
+      }
       for(; k<depth; k++)
       {
         Index w=0;
-        for(; w<(pack_packets*PacketSize)-3; w+=4)
+        for(; w<pack-3; w+=4)
         {
           Scalar a(cj(lhs(i+w+0, k))),
                  b(cj(lhs(i+w+1, k))),
@@ -1578,26 +1603,19 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj
           blockA[count++] = c;
           blockA[count++] = d;
         }
-        if(PacketSize%4)
-          for(;w<pack_packets*PacketSize;++w)
+        if(pack%4)
+          for(;w<pack;++w)
             blockA[count++] = cj(lhs(i+w, k));
       }
-      if(PanelMode) count += (pack_packets*PacketSize) * (stride-offset-depth);
+      if(PanelMode) count += pack * (stride-offset-depth);
     }
-    pack_packets--;
+    pack -= PacketSize;
+    if(pack<Pack2 && (pack+PacketSize)!=Pack2)
+      pack = Pack2;
   }
-
-//   if(rows-peeled_mc>=Pack2)
-//   {
-//     if(PanelMode) count += Pack2*offset;
-//     for(Index k=0; k<depth; k++)
-//       for(Index w=0; w<Pack2; w++)
-//         blockA[count++] = cj(lhs(peeled_mc+w, k));
-//     if(PanelMode) count += Pack2 * (stride-offset-depth);
-//     peeled_mc += Pack2;
-//   }
 
   for(; i<rows; i++)
   {
     if(PanelMode) count += offset;
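In the RowMajor packer the loop variable is now a row count `pack` that steps down by a whole packet per outer iteration, with one final `Pack2`-wide pass. A trace of the widths it visits, assuming Pack1=12, PacketSize=4, Pack2=2 (illustrative values):

    #include <cstdio>

    int main()
    {
        const int Pack1 = 12, PacketSize = 4, Pack2 = 2;
        int pack = Pack1;
        while (pack > 0)
        {
            std::printf("pack rows %d at a time\n", pack);  // 12, 8, 4, 2
            pack -= PacketSize;
            // Once pack drops below Pack2, run one extra Pack2-wide pass --
            // unless the pass that just finished was already the Pack2 pass,
            // in which case (pack + PacketSize) == Pack2 and the loop ends.
            if (pack < Pack2 && (pack + PacketSize) != Pack2)
                pack = Pack2;
        }
    }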