mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 11:49:02 +08:00
Vectorized the loop peeling of the inner loop of the block-panel matrix multiplication code. This speeds up the multiplication of matrices which size is not a multiple of the packet size.
This commit is contained in:
parent
39bfbd43f0
commit
ad59ade116
@ -206,6 +206,11 @@ public:
|
|||||||
dest = pload<LhsPacket>(a);
|
dest = pload<LhsPacket>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||||
|
{
|
||||||
|
dest = ploadu<LhsPacket>(a);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, AccPacket& tmp) const
|
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, AccPacket& tmp) const
|
||||||
{
|
{
|
||||||
// It would be a lot cleaner to call pmadd all the time. Unfortunately if we
|
// It would be a lot cleaner to call pmadd all the time. Unfortunately if we
|
||||||
@ -279,6 +284,11 @@ public:
|
|||||||
dest = pload<LhsPacket>(a);
|
dest = pload<LhsPacket>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||||
|
{
|
||||||
|
dest = ploadu<LhsPacket>(a);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
|
||||||
{
|
{
|
||||||
pbroadcast4(b, b0, b1, b2, b3);
|
pbroadcast4(b, b0, b1, b2, b3);
|
||||||
@ -334,6 +344,8 @@ public:
|
|||||||
&& packet_traits<Scalar>::Vectorizable,
|
&& packet_traits<Scalar>::Vectorizable,
|
||||||
RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
|
RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
|
||||||
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
|
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
|
||||||
|
LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
|
||||||
|
RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
|
||||||
|
|
||||||
// FIXME: should depend on NumberOfRegisters
|
// FIXME: should depend on NumberOfRegisters
|
||||||
nr = 4,
|
nr = 4,
|
||||||
@ -402,6 +414,11 @@ public:
|
|||||||
dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
|
dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||||
|
{
|
||||||
|
dest = ploadu<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacket& c, RhsPacket& /*tmp*/) const
|
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacket& c, RhsPacket& /*tmp*/) const
|
||||||
{
|
{
|
||||||
c.first = padd(pmul(a,b.first), c.first);
|
c.first = padd(pmul(a,b.first), c.first);
|
||||||
@ -509,6 +526,11 @@ public:
|
|||||||
dest = ploaddup<LhsPacket>(a);
|
dest = ploaddup<LhsPacket>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
|
||||||
|
{
|
||||||
|
dest = ploaddup<LhsPacket>(a);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
|
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
|
||||||
{
|
{
|
||||||
madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
|
madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
|
||||||
@ -706,49 +728,84 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
|||||||
const LhsScalar* blA = &blockA[i*strideA+offsetA];
|
const LhsScalar* blA = &blockA[i*strideA+offsetA];
|
||||||
prefetch(&blA[0]);
|
prefetch(&blA[0]);
|
||||||
|
|
||||||
// gets a 1 x 8 res block as registers
|
|
||||||
ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
|
|
||||||
// FIXME directly use blockB ???
|
// FIXME directly use blockB ???
|
||||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
|
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
|
||||||
// TODO peel this loop
|
|
||||||
for(Index k=0; k<depth; k++)
|
|
||||||
{
|
|
||||||
LhsScalar A0;
|
|
||||||
RhsScalar B_0, B_1;
|
|
||||||
|
|
||||||
A0 = blA[k];
|
if(nr == Traits::RhsPacketSize)
|
||||||
|
{
|
||||||
|
EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
|
||||||
|
|
||||||
B_0 = blB[0];
|
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
|
||||||
B_1 = blB[1];
|
typedef typename SwappedTraits::ResScalar SResScalar;
|
||||||
MADD(cj,A0,B_0,C0, B_0);
|
typedef typename SwappedTraits::LhsPacket SLhsPacket;
|
||||||
MADD(cj,A0,B_1,C1, B_1);
|
typedef typename SwappedTraits::RhsPacket SRhsPacket;
|
||||||
|
typedef typename SwappedTraits::ResPacket SResPacket;
|
||||||
|
typedef typename SwappedTraits::AccPacket SAccPacket;
|
||||||
|
SwappedTraits straits;
|
||||||
|
|
||||||
B_0 = blB[2];
|
SAccPacket C0;
|
||||||
B_1 = blB[3];
|
straits.initAcc(C0);
|
||||||
MADD(cj,A0,B_0,C2, B_0);
|
for(Index k=0; k<depth; k++)
|
||||||
MADD(cj,A0,B_1,C3, B_1);
|
{
|
||||||
|
SLhsPacket A0;
|
||||||
|
straits.loadLhsUnaligned(blB, A0);
|
||||||
|
SRhsPacket B_0;
|
||||||
|
straits.loadRhs(&blA[k], B_0);
|
||||||
|
SRhsPacket T0;
|
||||||
|
straits.madd(A0,B_0,C0,T0);
|
||||||
|
blB += nr;
|
||||||
|
}
|
||||||
|
SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
|
||||||
|
SResPacket alphav = pset1<SResPacket>(alpha);
|
||||||
|
straits.acc(C0, alphav, R);
|
||||||
|
pscatter(&res[j2*resStride + i], R, resStride);
|
||||||
|
|
||||||
B_0 = blB[4];
|
EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
|
||||||
B_1 = blB[5];
|
}
|
||||||
MADD(cj,A0,B_0,C4, B_0);
|
else
|
||||||
MADD(cj,A0,B_1,C5, B_1);
|
{
|
||||||
|
// gets a 1 x 8 res block as registers
|
||||||
|
ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
|
||||||
|
|
||||||
B_0 = blB[6];
|
for(Index k=0; k<depth; k++)
|
||||||
B_1 = blB[7];
|
{
|
||||||
MADD(cj,A0,B_0,C6, B_0);
|
LhsScalar A0;
|
||||||
MADD(cj,A0,B_1,C7, B_1);
|
RhsScalar B_0, B_1;
|
||||||
|
|
||||||
blB += 8;
|
A0 = blA[k];
|
||||||
}
|
|
||||||
res[(j2+0)*resStride + i] += alpha*C0;
|
B_0 = blB[0];
|
||||||
res[(j2+1)*resStride + i] += alpha*C1;
|
B_1 = blB[1];
|
||||||
res[(j2+2)*resStride + i] += alpha*C2;
|
MADD(cj,A0,B_0,C0, B_0);
|
||||||
res[(j2+3)*resStride + i] += alpha*C3;
|
MADD(cj,A0,B_1,C1, B_1);
|
||||||
res[(j2+4)*resStride + i] += alpha*C4;
|
|
||||||
res[(j2+5)*resStride + i] += alpha*C5;
|
B_0 = blB[2];
|
||||||
res[(j2+6)*resStride + i] += alpha*C6;
|
B_1 = blB[3];
|
||||||
res[(j2+7)*resStride + i] += alpha*C7;
|
MADD(cj,A0,B_0,C2, B_0);
|
||||||
}
|
MADD(cj,A0,B_1,C3, B_1);
|
||||||
|
|
||||||
|
B_0 = blB[4];
|
||||||
|
B_1 = blB[5];
|
||||||
|
MADD(cj,A0,B_0,C4, B_0);
|
||||||
|
MADD(cj,A0,B_1,C5, B_1);
|
||||||
|
|
||||||
|
B_0 = blB[6];
|
||||||
|
B_1 = blB[7];
|
||||||
|
MADD(cj,A0,B_0,C6, B_0);
|
||||||
|
MADD(cj,A0,B_1,C7, B_1);
|
||||||
|
|
||||||
|
blB += 8;
|
||||||
|
}
|
||||||
|
res[(j2+0)*resStride + i] += alpha*C0;
|
||||||
|
res[(j2+1)*resStride + i] += alpha*C1;
|
||||||
|
res[(j2+2)*resStride + i] += alpha*C2;
|
||||||
|
res[(j2+3)*resStride + i] += alpha*C3;
|
||||||
|
res[(j2+4)*resStride + i] += alpha*C4;
|
||||||
|
res[(j2+5)*resStride + i] += alpha*C5;
|
||||||
|
res[(j2+6)*resStride + i] += alpha*C6;
|
||||||
|
res[(j2+7)*resStride + i] += alpha*C7;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -839,35 +896,68 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
|
|||||||
const LhsScalar* blA = &blockA[i*strideA+offsetA];
|
const LhsScalar* blA = &blockA[i*strideA+offsetA];
|
||||||
prefetch(&blA[0]);
|
prefetch(&blA[0]);
|
||||||
|
|
||||||
// gets a 1 x 4 res block as registers
|
|
||||||
ResScalar C0(0), C1(0), C2(0), C3(0);
|
|
||||||
// FIXME directly use blockB ???
|
// FIXME directly use blockB ???
|
||||||
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
|
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
|
||||||
// TODO peel this loop
|
|
||||||
for(Index k=0; k<depth; k++)
|
|
||||||
{
|
|
||||||
LhsScalar A0;
|
|
||||||
RhsScalar B_0, B_1;
|
|
||||||
|
|
||||||
A0 = blA[k];
|
if(nr == Traits::RhsPacketSize)
|
||||||
|
{
|
||||||
|
EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
|
||||||
|
|
||||||
B_0 = blB[0];
|
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
|
||||||
B_1 = blB[1];
|
typedef typename SwappedTraits::ResScalar SResScalar;
|
||||||
MADD(cj,A0,B_0,C0, B_0);
|
typedef typename SwappedTraits::LhsPacket SLhsPacket;
|
||||||
MADD(cj,A0,B_1,C1, B_1);
|
typedef typename SwappedTraits::RhsPacket SRhsPacket;
|
||||||
|
typedef typename SwappedTraits::ResPacket SResPacket;
|
||||||
|
typedef typename SwappedTraits::AccPacket SAccPacket;
|
||||||
|
SwappedTraits straits;
|
||||||
|
|
||||||
B_0 = blB[2];
|
SAccPacket C0;
|
||||||
B_1 = blB[3];
|
straits.initAcc(C0);
|
||||||
MADD(cj,A0,B_0,C2, B_0);
|
for(Index k=0; k<depth; k++)
|
||||||
MADD(cj,A0,B_1,C3, B_1);
|
{
|
||||||
|
SLhsPacket A0;
|
||||||
|
straits.loadLhsUnaligned(blB, A0);
|
||||||
|
SRhsPacket B_0;
|
||||||
|
straits.loadRhs(&blA[k], B_0);
|
||||||
|
SRhsPacket T0;
|
||||||
|
straits.madd(A0,B_0,C0,T0);
|
||||||
|
blB += nr;
|
||||||
|
}
|
||||||
|
SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
|
||||||
|
SResPacket alphav = pset1<SResPacket>(alpha);
|
||||||
|
straits.acc(C0, alphav, R);
|
||||||
|
pscatter(&res[j2*resStride + i], R, resStride);
|
||||||
|
|
||||||
blB += 4;
|
EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
|
||||||
}
|
} else {
|
||||||
res[(j2+0)*resStride + i] += alpha*C0;
|
// gets a 1 x 4 res block as registers
|
||||||
res[(j2+1)*resStride + i] += alpha*C1;
|
ResScalar C0(0), C1(0), C2(0), C3(0);
|
||||||
res[(j2+2)*resStride + i] += alpha*C2;
|
|
||||||
res[(j2+3)*resStride + i] += alpha*C3;
|
for(Index k=0; k<depth; k++)
|
||||||
}
|
{
|
||||||
|
LhsScalar A0;
|
||||||
|
RhsScalar B_0, B_1;
|
||||||
|
|
||||||
|
A0 = blA[k];
|
||||||
|
|
||||||
|
B_0 = blB[0];
|
||||||
|
B_1 = blB[1];
|
||||||
|
MADD(cj,A0,B_0,C0, B_0);
|
||||||
|
MADD(cj,A0,B_1,C1, B_1);
|
||||||
|
|
||||||
|
B_0 = blB[2];
|
||||||
|
B_1 = blB[3];
|
||||||
|
MADD(cj,A0,B_0,C2, B_0);
|
||||||
|
MADD(cj,A0,B_1,C3, B_1);
|
||||||
|
|
||||||
|
blB += 4;
|
||||||
|
}
|
||||||
|
res[(j2+0)*resStride + i] += alpha*C0;
|
||||||
|
res[(j2+1)*resStride + i] += alpha*C1;
|
||||||
|
res[(j2+2)*resStride + i] += alpha*C2;
|
||||||
|
res[(j2+3)*resStride + i] += alpha*C3;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user