Vectorized the loop peeling of the inner loop of the block-panel matrix multiplication code. This speeds up the multiplication of matrices whose size is not a multiple of the packet size.

This commit is contained in:
Benoit Steiner 2014-03-28 12:11:23 -07:00
parent 39bfbd43f0
commit ad59ade116

View File

@ -206,6 +206,11 @@ public:
dest = pload<LhsPacket>(a); dest = pload<LhsPacket>(a);
} }
// Loads one LhsPacket starting at `a` into `dest`, using ploadu instead of
// the pload used by the load defined just above.
// NOTE(review): presumably intended for addresses that are not guaranteed to
// be packet-aligned -- confirm against the callers.
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
{
dest = ploadu<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, AccPacket& tmp) const EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, AccPacket& tmp) const
{ {
// It would be a lot cleaner to call pmadd all the time. Unfortunately if we // It would be a lot cleaner to call pmadd all the time. Unfortunately if we
@ -279,6 +284,11 @@ public:
dest = pload<LhsPacket>(a); dest = pload<LhsPacket>(a);
} }
// Loads one LhsPacket starting at `a` into `dest`, using ploadu instead of
// the pload used by the load defined just above.
// NOTE(review): presumably intended for addresses that are not guaranteed to
// be packet-aligned -- confirm against the callers.
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
{
dest = ploadu<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
{ {
pbroadcast4(b, b0, b1, b2, b3); pbroadcast4(b, b0, b1, b2, b3);
@ -334,6 +344,8 @@ public:
&& packet_traits<Scalar>::Vectorizable, && packet_traits<Scalar>::Vectorizable,
RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1, RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1, ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
// FIXME: should depend on NumberOfRegisters // FIXME: should depend on NumberOfRegisters
nr = 4, nr = 4,
@ -402,6 +414,11 @@ public:
dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a)); dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
} }
// Loads one LhsPacket starting at `a` into `dest` via ploadu.
// `a` is first reinterpreted as a pointer to the packet's element type
// (unpacket_traits<LhsPacket>::type), mirroring the cast done by the
// pload-based load defined just above.
// NOTE(review): presumably intended for addresses that are not guaranteed to
// be packet-aligned -- confirm against the callers.
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
{
dest = ploadu<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
}
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacket& c, RhsPacket& /*tmp*/) const EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacket& c, RhsPacket& /*tmp*/) const
{ {
c.first = padd(pmul(a,b.first), c.first); c.first = padd(pmul(a,b.first), c.first);
@ -509,6 +526,11 @@ public:
dest = ploaddup<LhsPacket>(a); dest = ploaddup<LhsPacket>(a);
} }
// Loads an LhsPacket from `a` into `dest` via ploaddup, i.e. the very same
// primitive used by the load defined just above; no separate unaligned
// variant is called here.
// NOTE(review): this relies on ploaddup accepting addresses that are not
// packet-aligned -- confirm for every supported target.
EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const
{
dest = ploaddup<LhsPacket>(a);
}
EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
{ {
madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type()); madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
@ -706,11 +728,45 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
const LhsScalar* blA = &blockA[i*strideA+offsetA]; const LhsScalar* blA = &blockA[i*strideA+offsetA];
prefetch(&blA[0]); prefetch(&blA[0]);
// gets a 1 x 8 res block as registers
ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
// FIXME directly use blockB ??? // FIXME directly use blockB ???
const RhsScalar* blB = &blockB[j2*strideB+offsetB*8]; const RhsScalar* blB = &blockB[j2*strideB+offsetB*8];
// TODO peel this loop
if(nr == Traits::RhsPacketSize)
{
EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
typedef typename SwappedTraits::ResScalar SResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
typedef typename SwappedTraits::RhsPacket SRhsPacket;
typedef typename SwappedTraits::ResPacket SResPacket;
typedef typename SwappedTraits::AccPacket SAccPacket;
SwappedTraits straits;
SAccPacket C0;
straits.initAcc(C0);
for(Index k=0; k<depth; k++)
{
SLhsPacket A0;
straits.loadLhsUnaligned(blB, A0);
SRhsPacket B_0;
straits.loadRhs(&blA[k], B_0);
SRhsPacket T0;
straits.madd(A0,B_0,C0,T0);
blB += nr;
}
SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
SResPacket alphav = pset1<SResPacket>(alpha);
straits.acc(C0, alphav, R);
pscatter(&res[j2*resStride + i], R, resStride);
EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
}
else
{
// gets a 1 x 8 res block as registers
ResScalar C0(0), C1(0), C2(0), C3(0), C4(0), C5(0), C6(0), C7(0);
for(Index k=0; k<depth; k++) for(Index k=0; k<depth; k++)
{ {
LhsScalar A0; LhsScalar A0;
@ -751,6 +807,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
} }
} }
} }
}
// Second pass using depth x 4 panels // Second pass using depth x 4 panels
// If nr==8, then we have at most one such panel // If nr==8, then we have at most one such panel
@ -839,11 +896,43 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
const LhsScalar* blA = &blockA[i*strideA+offsetA]; const LhsScalar* blA = &blockA[i*strideA+offsetA];
prefetch(&blA[0]); prefetch(&blA[0]);
// gets a 1 x 4 res block as registers
ResScalar C0(0), C1(0), C2(0), C3(0);
// FIXME directly use blockB ??? // FIXME directly use blockB ???
const RhsScalar* blB = &blockB[j2*strideB+offsetB*4]; const RhsScalar* blB = &blockB[j2*strideB+offsetB*4];
// TODO peel this loop
if(nr == Traits::RhsPacketSize)
{
EIGEN_ASM_COMMENT("begin_vectorized_multiplication_of_last_rows");
typedef gebp_traits<RhsScalar,LhsScalar,ConjugateRhs,ConjugateLhs> SwappedTraits;
typedef typename SwappedTraits::ResScalar SResScalar;
typedef typename SwappedTraits::LhsPacket SLhsPacket;
typedef typename SwappedTraits::RhsPacket SRhsPacket;
typedef typename SwappedTraits::ResPacket SResPacket;
typedef typename SwappedTraits::AccPacket SAccPacket;
SwappedTraits straits;
SAccPacket C0;
straits.initAcc(C0);
for(Index k=0; k<depth; k++)
{
SLhsPacket A0;
straits.loadLhsUnaligned(blB, A0);
SRhsPacket B_0;
straits.loadRhs(&blA[k], B_0);
SRhsPacket T0;
straits.madd(A0,B_0,C0,T0);
blB += nr;
}
SResPacket R = pgather<SResScalar, SResPacket>(&res[j2*resStride + i], resStride);
SResPacket alphav = pset1<SResPacket>(alpha);
straits.acc(C0, alphav, R);
pscatter(&res[j2*resStride + i], R, resStride);
EIGEN_ASM_COMMENT("end_vectorized_multiplication_of_last_rows");
} else {
// gets a 1 x 4 res block as registers
ResScalar C0(0), C1(0), C2(0), C3(0);
for(Index k=0; k<depth; k++) for(Index k=0; k<depth; k++)
{ {
LhsScalar A0; LhsScalar A0;
@ -870,6 +959,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
} }
} }
} }
}
// process remaining rhs/res columns one at a time // process remaining rhs/res columns one at a time
// => do the same but with nr==1 // => do the same but with nr==1