mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-19 08:09:36 +08:00
optimize the packing of lhs blocks for matrix-matrix products => significant speedup for small products
This commit is contained in:
parent
0e1e0a2a58
commit
238999045c
@ -1103,7 +1103,7 @@ EIGEN_ASM_COMMENT("mybegin4");
|
||||
#undef CJMADD
|
||||
|
||||
// pack a block of the lhs
|
||||
// The travesal is as follow (mr==4):
|
||||
// The traversal is as follows (mr==4):
|
||||
// 0 4 8 12 ...
|
||||
// 1 5 9 13 ...
|
||||
// 2 6 10 14 ...
|
||||
@ -1119,11 +1119,15 @@ EIGEN_ASM_COMMENT("mybegin4");
|
||||
template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder, bool Conjugate, bool PanelMode>
|
||||
struct gemm_pack_lhs
|
||||
{
|
||||
void operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows,
|
||||
EIGEN_DONT_INLINE void operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows,
|
||||
Index stride=0, Index offset=0)
|
||||
{
|
||||
// enum { PacketSize = packet_traits<Scalar>::size };
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
enum { PacketSize = packet_traits<Scalar>::size };
|
||||
|
||||
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK LHS");
|
||||
eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
|
||||
eigen_assert( (StorageOrder==RowMajor) || (Pack1%PacketSize)==0 && Pack1<=4*PacketSize);
|
||||
conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
|
||||
const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs,lhsStride);
|
||||
Index count = 0;
|
||||
@ -1131,9 +1135,44 @@ struct gemm_pack_lhs
|
||||
for(Index i=0; i<peeled_mc; i+=Pack1)
|
||||
{
|
||||
if(PanelMode) count += Pack1 * offset;
|
||||
for(Index k=0; k<depth; k++)
|
||||
for(Index w=0; w<Pack1; w++)
|
||||
blockA[count++] = cj(lhs(i+w, k));
|
||||
|
||||
if(StorageOrder==ColMajor)
|
||||
{
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
Packet A, B, C, D;
|
||||
if(Pack1>=1*PacketSize) A = ploadu<Packet>(&lhs(i+0*PacketSize, k));
|
||||
if(Pack1>=2*PacketSize) B = ploadu<Packet>(&lhs(i+1*PacketSize, k));
|
||||
if(Pack1>=3*PacketSize) C = ploadu<Packet>(&lhs(i+2*PacketSize, k));
|
||||
if(Pack1>=4*PacketSize) D = ploadu<Packet>(&lhs(i+3*PacketSize, k));
|
||||
if(Pack1>=1*PacketSize) { pstore(blockA+count, cj(A)); count+=PacketSize; }
|
||||
if(Pack1>=2*PacketSize) { pstore(blockA+count, cj(B)); count+=PacketSize; }
|
||||
if(Pack1>=3*PacketSize) { pstore(blockA+count, cj(C)); count+=PacketSize; }
|
||||
if(Pack1>=4*PacketSize) { pstore(blockA+count, cj(D)); count+=PacketSize; }
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(Index k=0; k<depth; k++)
|
||||
{
|
||||
// TODO add a vectorized transpose here
|
||||
Index w=0;
|
||||
for(; w<Pack1-3; w+=4)
|
||||
{
|
||||
Scalar a(cj(lhs(i+w+0, k))),
|
||||
b(cj(lhs(i+w+1, k))),
|
||||
c(cj(lhs(i+w+2, k))),
|
||||
d(cj(lhs(i+w+3, k)));
|
||||
blockA[count++] = a;
|
||||
blockA[count++] = b;
|
||||
blockA[count++] = c;
|
||||
blockA[count++] = d;
|
||||
}
|
||||
if(Pack1%4)
|
||||
for(;w<Pack1;++w)
|
||||
blockA[count++] = cj(lhs(i+w, k));
|
||||
}
|
||||
}
|
||||
if(PanelMode) count += Pack1 * (stride-offset-depth);
|
||||
}
|
||||
if(rows-peeled_mc>=Pack2)
|
||||
@ -1167,9 +1206,10 @@ struct gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, PanelMode>
|
||||
{
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
enum { PacketSize = packet_traits<Scalar>::size };
|
||||
void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols,
|
||||
EIGEN_DONT_INLINE void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols,
|
||||
Index stride=0, Index offset=0)
|
||||
{
|
||||
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
|
||||
eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
|
||||
conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
|
||||
Index packet_cols = (cols/nr) * nr;
|
||||
@ -1214,9 +1254,10 @@ template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode
|
||||
struct gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, PanelMode>
|
||||
{
|
||||
enum { PacketSize = packet_traits<Scalar>::size };
|
||||
void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols,
|
||||
EIGEN_DONT_INLINE void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols,
|
||||
Index stride=0, Index offset=0)
|
||||
{
|
||||
EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
|
||||
eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
|
||||
conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
|
||||
Index packet_cols = (cols/nr) * nr;
|
||||
|
Loading…
x
Reference in New Issue
Block a user