improve WIP new matrix product

This commit is contained in:
Gael Guennebaud 2009-02-27 17:18:52 +00:00
parent 40774c625e
commit 8ed186b9ab

View File

@ -342,56 +342,107 @@ static void ei_cache_friendly_product(
#else #else
// loops on each L2 cache friendly blocks of the result // loops on each L2 cache friendly blocks of the result
for(int l2i=0; l2i<rows; l2i+=l2BlockRows) for(int l2j=0; l2j<cols; l2j+=l2BlockCols)
{ {
for(int l2j=0; l2j<cols; l2j+=l2BlockCols) for(int l2i=0; l2i<rows; l2i+=l2BlockRows)
{ {
// We have selected a block of lhs // We have selected a block of lhs
// Packs this block into 'block' // Packs this block into 'block'
for(int j=0; j<l2BlockCols; ++j) int count = 0;
for(int j=0; j<l2BlockCols; j+=MaxBlockRows)
{ {
int count = 0; for(int i=0; i<l2BlockRows; i+=2*PacketSize)
for(int i=0; i<l2BlockRows; ++i) for (int w=0; w<MaxBlockRows; ++w)
{ for (int y=0; y<2*PacketSize; ++y)
block[ (j*l2BlockCols) + i] = lhs[(j+l2j)*rows+l2i+count++]; block[count++] = lhs[(j+l2j+w)*rows + l2i+i+ y];
}
} }
// loops on each L2 cache firendly block of the result/rhs // loops on each L2 cache firendly block of the result/rhs
for(int l2k=0; l2k<cols; l2k+=l2BlockCols) for(int l2k=0; l2k<cols; l2k+=l2BlockCols)
{ {
for(int j=0; j<l2BlockCols; ++j) for(int i=0; i<l2BlockRows; i+=MaxBlockRows)
{ {
for(int i=0; i<l2BlockRows; i+=PacketSize) for(int j=0; j<l2BlockCols; ++j)
{ {
PacketType A0, A1, A2, A3, A4, A5; PacketType A0, A1, A2, A3, A4, A5, A6, A7;
// Load the packets from rhs and reorder them // Load the packets from rhs and reorder them
// Here we need some vector reordering // Here we need some vector reordering
// Right now its hardcoded to packets of 4 elements // Right now its hardcoded to packets of 4 elements
A0 = ei_pset1(rhs[(j+l2k)*rows+(i+l2j)]); const Scalar* lrhs = &rhs[(j+l2k)*rows+(i+l2j)];
A1 = ei_pset1(rhs[(j+l2k)*rows+(i+l2j)+1]); A0 = ei_pset1(lrhs[0]);
A2 = ei_pset1(rhs[(j+l2k)*rows+(i+l2j)+2]); A1 = ei_pset1(lrhs[1]);
A3 = ei_pset1(rhs[(j+l2k)*rows+(i+l2j)+3]); A2 = ei_pset1(lrhs[2]);
A3 = ei_pset1(lrhs[3]);
for(int k=0; k<l2BlockRows; k+=PacketSize) if (MaxBlockRows==8)
{ {
PacketType L0, L1, L2, L3; A4 = ei_pset1(lrhs[4]);
A5 = ei_pset1(lrhs[5]);
A6 = ei_pset1(lrhs[6]);
A7 = ei_pset1(lrhs[7]);
}
Scalar * lb = &block[l2BlockRows * i];
for(int k=0; k<l2BlockRows; k+=2*PacketSize)
{
PacketType R0, R1, L0, L1, T0, T1;
asm("#begin sgemm");
// We perform "cross products" of vectors to avoid // We perform "cross products" of vectors to avoid
// reductions (horizontal ops) afterwards // reductions (horizontal ops) afterwards
A4 = ei_pload(&res[(j+l2k)*rows+l2i+k]); T0 = ei_pload(&res[(j+l2k)*rows+l2i+k]);
L0 = ei_pload(&block[ k + (i + 0)*l2BlockRows ]); T1 = ei_pload(&res[(j+l2k)*rows+l2i+k+PacketSize]);
L1 = ei_pload(&block[ k + (i + 1)*l2BlockRows ]); // uncomment to remove res cache miss
A4 = ei_pmadd(L0, A0, A4); // T0 = ei_pload(&res[k]);
L2 = ei_pload(&block[ k + (i + 2)*l2BlockRows ]); // T1 = ei_pload(&res[k+PacketSize]);
A4 = ei_pmadd(L1, A1, A4);
L3 = ei_pload(&block[ k + (i + 3)*l2BlockRows ]);
A4 = ei_pmadd(L2, A2, A4);
A4 = ei_pmadd(L3, A3, A4);
ei_pstore(&res[(j+l2k)*rows+l2i+k], A4); R0 = ei_pload(&lb[0*PacketSize]);
L0 = ei_pload(&lb[1*PacketSize]);
R1 = ei_pload(&lb[2*PacketSize]);
L1 = ei_pload(&lb[3*PacketSize]);
T0 = ei_pmadd(R0, A0, T0);
T1 = ei_pmadd(L0, A0, T1);
R0 = ei_pload(&lb[4*PacketSize]);
L0 = ei_pload(&lb[5*PacketSize]);
T0 = ei_pmadd(R1, A1, T0);
T1 = ei_pmadd(L1, A1, T1);
R1 = ei_pload(&lb[6*PacketSize]);
L1 = ei_pload(&lb[7*PacketSize]);
T0 = ei_pmadd(R0, A2, T0);
T1 = ei_pmadd(L0, A2, T1);
if(MaxBlockRows==8)
{
R0 = ei_pload(&lb[8*PacketSize]);
L0 = ei_pload(&lb[9*PacketSize]);
}
T0 = ei_pmadd(R1, A3, T0);
T1 = ei_pmadd(L1, A3, T1);
if(MaxBlockRows==8)
{
R1 = ei_pload(&lb[10*PacketSize]);
L1 = ei_pload(&lb[11*PacketSize]);
T0 = ei_pmadd(R0, A4, T0);
T1 = ei_pmadd(L0, A4, T1);
R0 = ei_pload(&lb[12*PacketSize]);
L0 = ei_pload(&lb[13*PacketSize]);
T0 = ei_pmadd(R1, A5, T0);
T1 = ei_pmadd(L1, A5, T1);
R1 = ei_pload(&lb[14*PacketSize]);
L1 = ei_pload(&lb[15*PacketSize]);
T0 = ei_pmadd(R0, A6, T0);
T1 = ei_pmadd(L0, A6, T1);
T0 = ei_pmadd(R1, A7, T0);
T1 = ei_pmadd(L1, A7, T1);
}
lb += MaxBlockRows*2*PacketSize;
ei_pstore(&res[(j+l2k)*rows+l2i+k], T0);
ei_pstore(&res[(j+l2k)*rows+l2i+k+PacketSize], T1);
// uncomment to remove res cache miss
// ei_pstore(&res[k], T0);
// ei_pstore(&res[k+PacketSize], T1);
asm("#end sgemm");
} }
} }
} }