mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-12 09:23:12 +08:00
Fix and optimize mixed products
This commit is contained in:
parent
0fa8290366
commit
11fbdcbc38
@ -180,14 +180,15 @@ public:
|
|||||||
|
|
||||||
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
||||||
|
|
||||||
// register block size along the N direction (must be either 2 or 4)
|
// register block size along the N direction must be 1 or 4
|
||||||
nr = 4,//NumberOfRegisters/4,
|
nr = 4,
|
||||||
|
|
||||||
// register block size along the M direction (currently, this one cannot be modified)
|
// register block size along the M direction (currently, this one cannot be modified)
|
||||||
#ifdef __FMA__
|
#ifdef __FMA__
|
||||||
|
// we assume 16 registers
|
||||||
mr = 3*LhsPacketSize,
|
mr = 3*LhsPacketSize,
|
||||||
#else
|
#else
|
||||||
mr = 2*LhsPacketSize,
|
mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*LhsPacketSize,
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
LhsProgress = LhsPacketSize,
|
LhsProgress = LhsPacketSize,
|
||||||
@ -209,15 +210,15 @@ public:
|
|||||||
p = pset1<ResPacket>(ResScalar(0));
|
p = pset1<ResPacket>(ResScalar(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
||||||
{
|
// {
|
||||||
pbroadcast4(b, b0, b1, b2, b3);
|
// pbroadcast4(b, b0, b1, b2, b3);
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
||||||
{
|
// {
|
||||||
pbroadcast2(b, b0, b1);
|
// pbroadcast2(b, b0, b1);
|
||||||
}
|
// }
|
||||||
|
|
||||||
template<typename RhsPacketType>
|
template<typename RhsPacketType>
|
||||||
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
|
EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const
|
||||||
@ -290,8 +291,13 @@ public:
|
|||||||
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
|
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
|
||||||
|
|
||||||
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
||||||
nr = NumberOfRegisters/2,
|
nr = 4,
|
||||||
mr = LhsPacketSize,
|
#ifdef __FMA__
|
||||||
|
// we assume 16 registers
|
||||||
|
mr = 3*LhsPacketSize,
|
||||||
|
#else
|
||||||
|
mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*LhsPacketSize,
|
||||||
|
#endif
|
||||||
|
|
||||||
LhsProgress = LhsPacketSize,
|
LhsProgress = LhsPacketSize,
|
||||||
RhsProgress = 1
|
RhsProgress = 1
|
||||||
@ -332,15 +338,15 @@ public:
|
|||||||
dest = ploadu<LhsPacket>(a);
|
dest = ploadu<LhsPacket>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
||||||
{
|
// {
|
||||||
pbroadcast4(b, b0, b1, b2, b3);
|
// pbroadcast4(b, b0, b1, b2, b3);
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
||||||
{
|
// {
|
||||||
pbroadcast2(b, b0, b1);
|
// pbroadcast2(b, b0, b1);
|
||||||
}
|
// }
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
|
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
|
||||||
{
|
{
|
||||||
@ -566,7 +572,7 @@ public:
|
|||||||
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
||||||
// FIXME: should depend on NumberOfRegisters
|
// FIXME: should depend on NumberOfRegisters
|
||||||
nr = 4,
|
nr = 4,
|
||||||
mr = ResPacketSize,
|
mr = (EIGEN_PLAIN_ENUM_MIN(16,NumberOfRegisters)/2/nr)*ResPacketSize,
|
||||||
|
|
||||||
LhsProgress = ResPacketSize,
|
LhsProgress = ResPacketSize,
|
||||||
RhsProgress = 1
|
RhsProgress = 1
|
||||||
@ -593,20 +599,26 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// linking error if instantiated without being optimized out:
|
// linking error if instantiated without being optimized out:
|
||||||
void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
|
// void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3);
|
||||||
|
//
|
||||||
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
|
||||||
{
|
// {
|
||||||
// FIXME not sure that's the best way to implement it!
|
// // FIXME not sure that's the best way to implement it!
|
||||||
b0 = pload1<RhsPacket>(b+0);
|
// b0 = pload1<RhsPacket>(b+0);
|
||||||
b1 = pload1<RhsPacket>(b+1);
|
// b1 = pload1<RhsPacket>(b+1);
|
||||||
}
|
// }
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
|
EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
|
||||||
{
|
{
|
||||||
dest = ploaddup<LhsPacket>(a);
|
dest = ploaddup<LhsPacket>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const
|
||||||
|
{
|
||||||
|
eigen_internal_assert(unpacket_traits<RhsPacket>::size<=4);
|
||||||
|
loadRhs(b,dest);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||||
{
|
{
|
||||||
dest = ploaddup<LhsPacket>(a);
|
dest = ploaddup<LhsPacket>(a);
|
||||||
@ -619,7 +631,13 @@ public:
|
|||||||
|
|
||||||
EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
|
EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
|
||||||
{
|
{
|
||||||
|
#ifdef EIGEN_VECTORIZE_FMA
|
||||||
|
EIGEN_UNUSED_VARIABLE(tmp);
|
||||||
|
c.v = pmadd(a,b.v,c.v);
|
||||||
|
#else
|
||||||
tmp = b; tmp.v = pmul(a,tmp.v); c = padd(c,tmp);
|
tmp = b; tmp.v = pmul(a,tmp.v); c = padd(c,tmp);
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void madd_impl(const LhsScalar& a, const RhsScalar& b, ResScalar& c, RhsScalar& /*tmp*/, const false_type&) const
|
EIGEN_STRONG_INLINE void madd_impl(const LhsScalar& a, const RhsScalar& b, ResScalar& c, RhsScalar& /*tmp*/, const false_type&) const
|
||||||
@ -956,7 +974,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
|||||||
for(Index k=0; k<peeled_kc; k+=pk)
|
for(Index k=0; k<peeled_kc; k+=pk)
|
||||||
{
|
{
|
||||||
IACA_START
|
IACA_START
|
||||||
EIGEN_ASM_COMMENT("begin gegp micro kernel 2p x 4");
|
EIGEN_ASM_COMMENT("begin gegp micro kernel 2pX4");
|
||||||
RhsPacket B_0, B1;
|
RhsPacket B_0, B1;
|
||||||
|
|
||||||
#define EIGEN_GEBGP_ONESTEP(K) \
|
#define EIGEN_GEBGP_ONESTEP(K) \
|
||||||
@ -1134,7 +1152,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
|||||||
for(Index k=0; k<peeled_kc; k+=pk)
|
for(Index k=0; k<peeled_kc; k+=pk)
|
||||||
{
|
{
|
||||||
IACA_START
|
IACA_START
|
||||||
EIGEN_ASM_COMMENT("begin gegp micro kernel 2p x 4");
|
EIGEN_ASM_COMMENT("begin gegp micro kernel 1pX4");
|
||||||
RhsPacket B_0, B1;
|
RhsPacket B_0, B1;
|
||||||
|
|
||||||
#define EIGEN_GEBGP_ONESTEP(K) \
|
#define EIGEN_GEBGP_ONESTEP(K) \
|
||||||
@ -1160,7 +1178,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
|||||||
EIGEN_GEBGP_ONESTEP(7);
|
EIGEN_GEBGP_ONESTEP(7);
|
||||||
|
|
||||||
blB += pk*4*RhsProgress;
|
blB += pk*4*RhsProgress;
|
||||||
blA += pk*(1*Traits::LhsProgress);
|
blA += pk*1*LhsProgress;
|
||||||
IACA_END
|
IACA_END
|
||||||
}
|
}
|
||||||
// process remaining peeled loop
|
// process remaining peeled loop
|
||||||
@ -1169,7 +1187,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
|||||||
RhsPacket B_0, B1;
|
RhsPacket B_0, B1;
|
||||||
EIGEN_GEBGP_ONESTEP(0);
|
EIGEN_GEBGP_ONESTEP(0);
|
||||||
blB += 4*RhsProgress;
|
blB += 4*RhsProgress;
|
||||||
blA += 1*Traits::LhsProgress;
|
blA += 1*LhsProgress;
|
||||||
}
|
}
|
||||||
#undef EIGEN_GEBGP_ONESTEP
|
#undef EIGEN_GEBGP_ONESTEP
|
||||||
|
|
||||||
@ -1439,6 +1457,8 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
|
|||||||
const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
|
const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
|
||||||
const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
|
const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
|
||||||
const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
|
const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
|
||||||
|
const Index peeled_mc0 = Pack2>=1*PacketSize ? peeled_mc1
|
||||||
|
: Pack2>1 ? (rows/Pack2)*Pack2 : 0;
|
||||||
|
|
||||||
// Pack 3 packets
|
// Pack 3 packets
|
||||||
if(Pack1>=3*PacketSize)
|
if(Pack1>=3*PacketSize)
|
||||||
@ -1496,16 +1516,20 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Pack scalars
|
// Pack scalars
|
||||||
// if(rows-peeled_mc>=Pack2)
|
if(Pack2<PacketSize && Pack2>1)
|
||||||
// {
|
{
|
||||||
// if(PanelMode) count += Pack2*offset;
|
for(Index i=peeled_mc1; i<peeled_mc0; i+=Pack2)
|
||||||
// for(Index k=0; k<depth; k++)
|
{
|
||||||
// for(Index w=0; w<Pack2; w++)
|
if(PanelMode) count += Pack2 * offset;
|
||||||
// blockA[count++] = cj(lhs(peeled_mc+w, k));
|
|
||||||
// if(PanelMode) count += Pack2 * (stride-offset-depth);
|
for(Index k=0; k<depth; k++)
|
||||||
// peeled_mc += Pack2;
|
for(Index w=0; w<Pack2; w++)
|
||||||
// }
|
blockA[count++] = cj(lhs(i+w, k));
|
||||||
for(Index i=peeled_mc1; i<rows; i++)
|
|
||||||
|
if(PanelMode) count += Pack2 * (stride-offset-depth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for(Index i=peeled_mc0; i<rows; i++)
|
||||||
{
|
{
|
||||||
if(PanelMode) count += offset;
|
if(PanelMode) count += offset;
|
||||||
for(Index k=0; k<depth; k++)
|
for(Index k=0; k<depth; k++)
|
||||||
@ -1539,35 +1563,36 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj
|
|||||||
// const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
|
// const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
|
||||||
// const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
|
// const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
|
||||||
|
|
||||||
int pack_packets = Pack1/PacketSize;
|
int pack = Pack1;
|
||||||
Index i = 0;
|
Index i = 0;
|
||||||
while(pack_packets>0)
|
while(pack>0)
|
||||||
{
|
{
|
||||||
|
|
||||||
Index remaining_rows = rows-i;
|
Index remaining_rows = rows-i;
|
||||||
Index peeled_mc = i+(remaining_rows/(pack_packets*PacketSize))*(pack_packets*PacketSize);
|
Index peeled_mc = i+(remaining_rows/pack)*pack;
|
||||||
// std::cout << "pack_packets = " << pack_packets << " from " << i << " to " << peeled_mc << "\n";
|
for(; i<peeled_mc; i+=pack)
|
||||||
for(; i<peeled_mc; i+=pack_packets*PacketSize)
|
|
||||||
{
|
{
|
||||||
if(PanelMode) count += (pack_packets*PacketSize) * offset;
|
if(PanelMode) count += pack * offset;
|
||||||
|
|
||||||
const Index peeled_k = (depth/PacketSize)*PacketSize;
|
const Index peeled_k = (depth/PacketSize)*PacketSize;
|
||||||
Index k=0;
|
Index k=0;
|
||||||
for(; k<peeled_k; k+=PacketSize)
|
if(pack>=PacketSize)
|
||||||
{
|
{
|
||||||
for (Index m = 0; m < (pack_packets*PacketSize); m += PacketSize)
|
for(; k<peeled_k; k+=PacketSize)
|
||||||
{
|
{
|
||||||
Kernel<Packet> kernel;
|
for (Index m = 0; m < pack; m += PacketSize)
|
||||||
for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k));
|
{
|
||||||
ptranspose(kernel);
|
Kernel<Packet> kernel;
|
||||||
for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack_packets*PacketSize)*p, cj.pconj(kernel.packet[p]));
|
for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k));
|
||||||
|
ptranspose(kernel);
|
||||||
|
for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel.packet[p]));
|
||||||
|
}
|
||||||
|
count += PacketSize*pack;
|
||||||
}
|
}
|
||||||
count += PacketSize*(pack_packets*PacketSize);
|
|
||||||
}
|
}
|
||||||
for(; k<depth; k++)
|
for(; k<depth; k++)
|
||||||
{
|
{
|
||||||
Index w=0;
|
Index w=0;
|
||||||
for(; w<(pack_packets*PacketSize)-3; w+=4)
|
for(; w<pack-3; w+=4)
|
||||||
{
|
{
|
||||||
Scalar a(cj(lhs(i+w+0, k))),
|
Scalar a(cj(lhs(i+w+0, k))),
|
||||||
b(cj(lhs(i+w+1, k))),
|
b(cj(lhs(i+w+1, k))),
|
||||||
@ -1578,26 +1603,19 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj
|
|||||||
blockA[count++] = c;
|
blockA[count++] = c;
|
||||||
blockA[count++] = d;
|
blockA[count++] = d;
|
||||||
}
|
}
|
||||||
if(PacketSize%4)
|
if(pack%4)
|
||||||
for(;w<pack_packets*PacketSize;++w)
|
for(;w<pack;++w)
|
||||||
blockA[count++] = cj(lhs(i+w, k));
|
blockA[count++] = cj(lhs(i+w, k));
|
||||||
}
|
}
|
||||||
|
|
||||||
if(PanelMode) count += (pack_packets*PacketSize) * (stride-offset-depth);
|
if(PanelMode) count += pack * (stride-offset-depth);
|
||||||
}
|
}
|
||||||
|
|
||||||
pack_packets--;
|
pack -= PacketSize;
|
||||||
|
if(pack<Pack2 && (pack+PacketSize)!=Pack2)
|
||||||
|
pack = Pack2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// if(rows-peeled_mc>=Pack2)
|
|
||||||
// {
|
|
||||||
// if(PanelMode) count += Pack2*offset;
|
|
||||||
// for(Index k=0; k<depth; k++)
|
|
||||||
// for(Index w=0; w<Pack2; w++)
|
|
||||||
// blockA[count++] = cj(lhs(peeled_mc+w, k));
|
|
||||||
// if(PanelMode) count += Pack2 * (stride-offset-depth);
|
|
||||||
// peeled_mc += Pack2;
|
|
||||||
// }
|
|
||||||
for(; i<rows; i++)
|
for(; i<rows; i++)
|
||||||
{
|
{
|
||||||
if(PanelMode) count += offset;
|
if(PanelMode) count += offset;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user