This commit is contained in:
Everton Constantino 2021-04-23 15:39:04 +00:00
parent c62ed9b214
commit 646d92c7f1
2 changed files with 34 additions and 562 deletions

View File

@ -17,9 +17,16 @@ namespace internal {
template<int CPU, typename LhsScalar, typename RhsScalar>
constexpr int SHAPES_COUNT<0, CPU, LhsScalar, RhsScalar> = 7;
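// lhs_progress x depth_progress x rhs_progress (depth_progress > 1 matrix ops) x pointer to next rhs_progress x pointer to next lhs_progress x pointer to next depth_progress on the shapes map (SHAPES_POINTER_END = no next shape)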
template<int CPU, typename LhsScalar, typename RhsScalar>
constexpr int SHAPES<0, CPU, LhsScalar, RhsScalar>[SHAPES_COUNT<0, CPU, LhsScalar,RhsScalar>][SHAPES_DIMENSION] =
{{1,1,1,SHAPES_POINTER_END},{4,1,1,0},{1,1,4,1},{4,1,4,1},{4,4,4,1},{8,1,4,1},{8,4,4,1}};
{ {1,1,1,SHAPES_POINTER_END, SHAPES_POINTER_END, SHAPES_POINTER_END},
{4,1,1, 0, 0, SHAPES_POINTER_END},
{1,1,4, 1, SHAPES_POINTER_END, SHAPES_POINTER_END},
{4,1,4, 1, 2, SHAPES_POINTER_END},
{4,4,4, 1, 2, 3},
{8,1,4, 1, 4, SHAPES_POINTER_END},
{8,4,4, 1, 4, 5}};
template<int CPU, typename Scalar, typename ResScalar, typename DataMapper>
struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 1>

View File

@ -18,555 +18,16 @@ namespace Eigen {
namespace internal {
#ifdef __OLD__
/**
 * Read-only view over an already packed operand block.
 *
 * The block is treated as two regions:
 *  - a "packed" region covering the largest multiple of the packet size that
 *    fits in the vectorized dimension (rows when IsLhs, cols otherwise);
 *  - a "residue" region holding the remaining rows / columns.
 *
 * \tparam Scalar  element type of the packed block
 * \tparam Packet  packet type associated with Scalar (not referenced by the class itself)
 * \tparam Index   integer type used for sizes and offsets
 * \tparam IsLhs   true for the left-hand operand, false for the right-hand one
 */
template<typename Scalar, typename Packet, typename Index, bool IsLhs = true>
class PackMap
{
  // Compile-time constant: no per-object storage and the class stays assignable.
  static constexpr int packetSize = packet_traits<Scalar>::size;

  const Scalar *packed_block;
  const Scalar *residue_block;
  Index packed_stride;
  Index residue_size;
  Index rows, cols;
  Index offset, stride;   // offset is stored but not read by any accessor below

  // Size of the dimension that gets split into packets.
  static Index vectorizedDim(Index rows, Index cols)
  {
    return IsLhs ? rows : cols;
  }

  // Largest multiple of packetSize not exceeding the vectorized dimension.
  static Index packedExtent(Index rows, Index cols)
  {
    return (vectorizedDim(rows, cols) / packetSize) * packetSize;
  }

public:
  /**
   * Builds a map with an explicitly supplied residue region.
   *
   * \param packed_block   start of the packed region
   * \param residue_block  start of the residue region
   * \param rows, cols     dimensions of the operand block
   * \param offset         stored as-is
   * \param stride         distance, in scalars, between consecutive residue entries
   */
  PackMap(const Scalar *packed_block, const Scalar *residue_block, Index rows, Index cols, Index offset, Index stride)
    : packed_block(packed_block),
      residue_block(residue_block),
      packed_stride(packedExtent(rows, cols)),
      residue_size(vectorizedDim(rows, cols) % packetSize),
      rows(rows), cols(cols),
      offset(offset), stride(stride)
  {
  }

  /**
   * Builds a map whose residue region directly follows the packed region:
   * packed_block + packed_size * cols for the LHS,
   * packed_block + packed_size * rows for the RHS.
   *
   * Delegates to the full constructor so that offset and stride are always
   * initialized; get_residue_at() depends on stride.
   */
  PackMap(const Scalar *packed_block, Index rows, Index cols, Index offset, Index stride)
    : PackMap(packed_block,
              packed_block + packedExtent(rows, cols) * (IsLhs ? cols : rows),
              rows, cols, offset, stride)
  {
  }

  /** Number of rows (LHS) or columns (RHS) covered by whole packets. */
  EIGEN_STRONG_INLINE Index get_packed_size() const
  {
    return packed_stride;
  }

  /** Number of rows (LHS) or columns (RHS) left over after the packed region. */
  EIGEN_STRONG_INLINE Index get_residue_size() const
  {
    return residue_size;
  }

  /**
   * Pointer into the packed region.
   * LHS: packed_block + at.
   * RHS: packed_block + at * packetSize * rows (at counts packet-wide panels).
   */
  EIGEN_STRONG_INLINE const Scalar* get_packed_at(Index at) const
  {
    return IsLhs ? packed_block + at : packed_block + at*packetSize*rows;
  }

  /** Pointer to the at-th residue entry: residue_block + stride * at. */
  EIGEN_STRONG_INLINE const Scalar* get_residue_at(Index at) const
  {
    return residue_block + stride*at;
  }
};
// Accumulates the product of the packed blocks blockA and blockB into res.
//
// The result is walked in panels of accRhsProgress (4) columns; inside each
// column panel the rows are handled 3, then 2, then 1 packet of accLhsProgress
// (4) rows at a time, and finally one scalar row at a time. Columns that do not
// fill a whole panel are handled one by one afterwards.
//
// res      destination mapper; every computed value is ADDED to what it holds
// blockA   packed left-hand block, viewed through PackMap as rows x depth
// blockB   packed right-hand block, viewed through PackMap as depth x cols
// strideA/strideB  a value of -1 is replaced by depth
// offsetA/offsetB  only forwarded to PackMap
//
// NOTE(review): alpha is only used to build pAlpha, and pAlpha is never read
// afterwards, so the accumulated product is not scaled by alpha here. Confirm
// whether that is intended.
// NOTE(review): the packet lanes are read with operator[] and combined with
// built-in '*', '+' and '+='; this presumably relies on packet types that
// support those operators directly - confirm for the targeted architecture.
template<typename ResScalar, typename AccScalar, typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper>
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB,
Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
using AccPacket = typename packet_traits<AccScalar>::type;
using LhsPacket = typename packet_traits<LhsScalar>::type;
using RhsPacket = typename packet_traits<RhsScalar>::type;
using ResPacket = typename packet_traits<ResScalar>::type;
using LinearMapper = typename DataMapper::LinearMapper;
// -1 means "no explicit stride": fall back to depth.
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
// Built but not referenced again in this function (see NOTE(review) above).
ResPacket pAlpha = pset1<ResPacket>(alpha);
#ifdef __DEBUG__
// Debug dump of both packed blocks, with a line break every stride elements.
std::cout << "blockA" << std::endl;
for(auto i = 0; i < rows*depth; i++)
{
if(i % strideA == 0 && i > 0)
std::cout << std::endl;
std::cout << blockA[i] << " ";
}
std::cout << std::endl;
std::cout << "blockB" << std::endl;
for(auto i = 0; i < depth*cols; i++)
{
if(i % strideB == 0 && i > 0)
std::cout << std::endl;
std::cout << blockB[i] << " ";
}
std::cout << std::endl;
#endif
// Hard-coded tile sizes: 4 rows per LHS packet step, 4 columns per RHS panel.
int accLhsProgress = 4;
int accRhsProgress = 4;
PackMap<LhsScalar, LhsPacket, Index> lhsMap(blockA, rows, depth, offsetA, strideA);
PackMap<RhsScalar, RhsPacket, Index, false> rhsMap(blockB, depth, cols, offsetB, strideB);
auto col = 0;
// ---- Full column panels (4 columns at a time) ----
for(; col + accRhsProgress <= rhsMap.get_packed_size(); col+=accRhsProgress)
{
auto row = 0;
// 3 LHS packets (12 rows) x 4 columns per iteration.
for(; row + 3*accLhsProgress <= lhsMap.get_packed_size(); row+=3*accLhsProgress)
{
const LhsScalar *lhs_ptr1 = lhsMap.get_packed_at(row + 0*accLhsProgress);
const LhsScalar *lhs_ptr2 = lhsMap.get_packed_at(row + 1*accLhsProgress);
const LhsScalar *lhs_ptr3 = lhsMap.get_packed_at(row + 2*accLhsProgress);
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress);
// One accumulator block per LHS packet; packet[j] accumulates result column col + j.
PacketBlock<AccPacket, 4> acc1;
acc1.packet[0] = pset1<AccPacket>(0);
acc1.packet[1] = pset1<AccPacket>(0);
acc1.packet[2] = pset1<AccPacket>(0);
acc1.packet[3] = pset1<AccPacket>(0);
PacketBlock<AccPacket, 4> acc2;
acc2.packet[0] = pset1<AccPacket>(0);
acc2.packet[1] = pset1<AccPacket>(0);
acc2.packet[2] = pset1<AccPacket>(0);
acc2.packet[3] = pset1<AccPacket>(0);
PacketBlock<AccPacket, 4> acc3;
acc3.packet[0] = pset1<AccPacket>(0);
acc3.packet[1] = pset1<AccPacket>(0);
acc3.packet[2] = pset1<AccPacket>(0);
acc3.packet[3] = pset1<AccPacket>(0);
// rIJ: destination for LHS packet I (rows row + I*4 ...) and column col + J.
LinearMapper r00 = res.getLinearMapper(row + 0*accLhsProgress, col + 0);
LinearMapper r01 = res.getLinearMapper(row + 0*accLhsProgress, col + 1);
LinearMapper r02 = res.getLinearMapper(row + 0*accLhsProgress, col + 2);
LinearMapper r03 = res.getLinearMapper(row + 0*accLhsProgress, col + 3);
LinearMapper r10 = res.getLinearMapper(row + 1*accLhsProgress, col + 0);
LinearMapper r11 = res.getLinearMapper(row + 1*accLhsProgress, col + 1);
LinearMapper r12 = res.getLinearMapper(row + 1*accLhsProgress, col + 2);
LinearMapper r13 = res.getLinearMapper(row + 1*accLhsProgress, col + 3);
LinearMapper r20 = res.getLinearMapper(row + 2*accLhsProgress, col + 0);
LinearMapper r21 = res.getLinearMapper(row + 2*accLhsProgress, col + 1);
LinearMapper r22 = res.getLinearMapper(row + 2*accLhsProgress, col + 2);
LinearMapper r23 = res.getLinearMapper(row + 2*accLhsProgress, col + 3);
auto k = 0;
for(; k < depth; k++)
{
// Load 4 RHS values and broadcast each one into its own packet.
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
PacketBlock<RhsPacket, 4> pbrhs;
pbrhs.packet[0] = pset1<RhsPacket>(prhs[0]);
pbrhs.packet[1] = pset1<RhsPacket>(prhs[1]);
pbrhs.packet[2] = pset1<RhsPacket>(prhs[2]);
pbrhs.packet[3] = pset1<RhsPacket>(prhs[3]);
LhsPacket plhs1 = pload<LhsPacket>(lhs_ptr1);
LhsPacket plhs2 = pload<LhsPacket>(lhs_ptr2);
LhsPacket plhs3 = pload<LhsPacket>(lhs_ptr3);
acc1.packet[0] += plhs1*pbrhs.packet[0];
acc1.packet[1] += plhs1*pbrhs.packet[1];
acc1.packet[2] += plhs1*pbrhs.packet[2];
acc1.packet[3] += plhs1*pbrhs.packet[3];
acc2.packet[0] += plhs2*pbrhs.packet[0];
acc2.packet[1] += plhs2*pbrhs.packet[1];
acc2.packet[2] += plhs2*pbrhs.packet[2];
acc2.packet[3] += plhs2*pbrhs.packet[3];
acc3.packet[0] += plhs3*pbrhs.packet[0];
acc3.packet[1] += plhs3*pbrhs.packet[1];
acc3.packet[2] += plhs3*pbrhs.packet[2];
acc3.packet[3] += plhs3*pbrhs.packet[3];
// Step to the next depth index: rows rounded down to a multiple of 4 for the LHS, 4 scalars for the RHS.
lhs_ptr1 += (rows/accLhsProgress)*accLhsProgress;
lhs_ptr2 += (rows/accLhsProgress)*accLhsProgress;
lhs_ptr3 += (rows/accLhsProgress)*accLhsProgress;
rhs_ptr += accRhsProgress;
}
// Add the accumulated tiles to the destination.
r00.storePacket(0,r00.template loadPacket<ResPacket>(0) + acc1.packet[0]);
r01.storePacket(0,r01.template loadPacket<ResPacket>(0) + acc1.packet[1]);
r02.storePacket(0,r02.template loadPacket<ResPacket>(0) + acc1.packet[2]);
r03.storePacket(0,r03.template loadPacket<ResPacket>(0) + acc1.packet[3]);
r10.storePacket(0,r10.template loadPacket<ResPacket>(0) + acc2.packet[0]);
r11.storePacket(0,r11.template loadPacket<ResPacket>(0) + acc2.packet[1]);
r12.storePacket(0,r12.template loadPacket<ResPacket>(0) + acc2.packet[2]);
r13.storePacket(0,r13.template loadPacket<ResPacket>(0) + acc2.packet[3]);
r20.storePacket(0,r20.template loadPacket<ResPacket>(0) + acc3.packet[0]);
r21.storePacket(0,r21.template loadPacket<ResPacket>(0) + acc3.packet[1]);
r22.storePacket(0,r22.template loadPacket<ResPacket>(0) + acc3.packet[2]);
r23.storePacket(0,r23.template loadPacket<ResPacket>(0) + acc3.packet[3]);
}
// 2 LHS packets (8 rows) x 4 columns per iteration.
for(; row + 2*accLhsProgress <= lhsMap.get_packed_size(); row+=2*accLhsProgress)
{
const LhsScalar *lhs_ptr1 = lhsMap.get_packed_at(row + 0*accLhsProgress);
const LhsScalar *lhs_ptr2 = lhsMap.get_packed_at(row + 1*accLhsProgress);
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress);
PacketBlock<AccPacket, 4> acc1;
acc1.packet[0] = pset1<AccPacket>(0);
acc1.packet[1] = pset1<AccPacket>(0);
acc1.packet[2] = pset1<AccPacket>(0);
acc1.packet[3] = pset1<AccPacket>(0);
PacketBlock<AccPacket, 4> acc2;
acc2.packet[0] = pset1<AccPacket>(0);
acc2.packet[1] = pset1<AccPacket>(0);
acc2.packet[2] = pset1<AccPacket>(0);
acc2.packet[3] = pset1<AccPacket>(0);
LinearMapper r00 = res.getLinearMapper(row + 0*accLhsProgress, col + 0);
LinearMapper r01 = res.getLinearMapper(row + 0*accLhsProgress, col + 1);
LinearMapper r02 = res.getLinearMapper(row + 0*accLhsProgress, col + 2);
LinearMapper r03 = res.getLinearMapper(row + 0*accLhsProgress, col + 3);
LinearMapper r10 = res.getLinearMapper(row + 1*accLhsProgress, col + 0);
LinearMapper r11 = res.getLinearMapper(row + 1*accLhsProgress, col + 1);
LinearMapper r12 = res.getLinearMapper(row + 1*accLhsProgress, col + 2);
LinearMapper r13 = res.getLinearMapper(row + 1*accLhsProgress, col + 3);
auto k = 0;
for(; k < depth; k++)
{
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
PacketBlock<RhsPacket, 4> pbrhs;
pbrhs.packet[0] = pset1<RhsPacket>(prhs[0]);
pbrhs.packet[1] = pset1<RhsPacket>(prhs[1]);
pbrhs.packet[2] = pset1<RhsPacket>(prhs[2]);
pbrhs.packet[3] = pset1<RhsPacket>(prhs[3]);
LhsPacket plhs1 = pload<LhsPacket>(lhs_ptr1);
LhsPacket plhs2 = pload<LhsPacket>(lhs_ptr2);
acc1.packet[0] += plhs1*pbrhs.packet[0];
acc1.packet[1] += plhs1*pbrhs.packet[1];
acc1.packet[2] += plhs1*pbrhs.packet[2];
acc1.packet[3] += plhs1*pbrhs.packet[3];
acc2.packet[0] += plhs2*pbrhs.packet[0];
acc2.packet[1] += plhs2*pbrhs.packet[1];
acc2.packet[2] += plhs2*pbrhs.packet[2];
acc2.packet[3] += plhs2*pbrhs.packet[3];
lhs_ptr1 += (rows/accLhsProgress)*accLhsProgress;
lhs_ptr2 += (rows/accLhsProgress)*accLhsProgress;
rhs_ptr += accRhsProgress;
}
r00.storePacket(0,r00.template loadPacket<ResPacket>(0) + acc1.packet[0]);
r01.storePacket(0,r01.template loadPacket<ResPacket>(0) + acc1.packet[1]);
r02.storePacket(0,r02.template loadPacket<ResPacket>(0) + acc1.packet[2]);
r03.storePacket(0,r03.template loadPacket<ResPacket>(0) + acc1.packet[3]);
r10.storePacket(0,r10.template loadPacket<ResPacket>(0) + acc2.packet[0]);
r11.storePacket(0,r11.template loadPacket<ResPacket>(0) + acc2.packet[1]);
r12.storePacket(0,r12.template loadPacket<ResPacket>(0) + acc2.packet[2]);
r13.storePacket(0,r13.template loadPacket<ResPacket>(0) + acc2.packet[3]);
}
// 1 LHS packet (4 rows) x 4 columns per iteration.
for(; row + accLhsProgress <= lhsMap.get_packed_size(); row+=accLhsProgress)
{
const LhsScalar *lhs_ptr = lhsMap.get_packed_at(row);
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress);
PacketBlock<AccPacket, 4> acc;
acc.packet[0] = pset1<AccPacket>(0);
acc.packet[1] = pset1<AccPacket>(0);
acc.packet[2] = pset1<AccPacket>(0);
acc.packet[3] = pset1<AccPacket>(0);
LinearMapper r0 = res.getLinearMapper(row, col + 0);
LinearMapper r1 = res.getLinearMapper(row, col + 1);
LinearMapper r2 = res.getLinearMapper(row, col + 2);
LinearMapper r3 = res.getLinearMapper(row, col + 3);
auto k = 0;
for(; k < depth; k++)
{
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
PacketBlock<RhsPacket, 4> pbrhs;
pbrhs.packet[0] = pset1<RhsPacket>(prhs[0]);
pbrhs.packet[1] = pset1<RhsPacket>(prhs[1]);
pbrhs.packet[2] = pset1<RhsPacket>(prhs[2]);
pbrhs.packet[3] = pset1<RhsPacket>(prhs[3]);
LhsPacket plhs = pload<LhsPacket>(lhs_ptr);
#ifdef __NDEBUG__
// Optional per-step trace of the operands.
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
#endif
acc.packet[0] += plhs*pbrhs.packet[0];
acc.packet[1] += plhs*pbrhs.packet[1];
acc.packet[2] += plhs*pbrhs.packet[2];
acc.packet[3] += plhs*pbrhs.packet[3];
lhs_ptr += (rows/accLhsProgress)*accLhsProgress;
rhs_ptr += accRhsProgress;
}
r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc.packet[0]);
r1.storePacket(0,r1.template loadPacket<ResPacket>(0) + acc.packet[1]);
r2.storePacket(0,r2.template loadPacket<ResPacket>(0) + acc.packet[2]);
r3.storePacket(0,r3.template loadPacket<ResPacket>(0) + acc.packet[3]);
}
// Leftover rows: one scalar LHS row against the 4-column RHS panel.
auto row_residue = 0;
for(;row < rows; row++)
{
const LhsScalar *lhs_ptr = lhsMap.get_residue_at(row_residue);
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress);
PacketBlock<AccPacket, 1> acc;
acc.packet[0] = pset1<AccPacket>(0);
auto k = 0;
for(; k < depth; k++)
{
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
// plhs is only read by the trace below; the accumulation uses *lhs_ptr directly.
LhsPacket plhs = pset1<LhsPacket>(*lhs_ptr);
#ifdef __NDEBUG__
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
#endif
acc.packet[0] += (*lhs_ptr)*prhs;
lhs_ptr++;
rhs_ptr += accRhsProgress;
}
// Lane j of the accumulator belongs to column col + j.
res(row, col + 0) += acc.packet[0][0];
res(row, col + 1) += acc.packet[0][1];
res(row, col + 2) += acc.packet[0][2];
res(row, col + 3) += acc.packet[0][3];
row_residue++;
}
}
// ---- Leftover columns, one at a time ----
auto col_residue = 0;
for(; col < cols; col++)
{
auto row = 0;
// 1 LHS packet (4 rows) against a single RHS column.
for(; row + accLhsProgress <= lhsMap.get_packed_size(); row+=accLhsProgress)
{
const LhsScalar *lhs_ptr = lhsMap.get_packed_at(row);
const RhsScalar *rhs_ptr = rhsMap.get_residue_at(col_residue);
PacketBlock<AccPacket, 1> acc;
acc.packet[0] = pset1<AccPacket>(0);
LinearMapper r0 = res.getLinearMapper(row, col + 0);
auto k = 0;
for(; k < depth; k++)
{
RhsPacket prhs = pset1<RhsPacket>(*rhs_ptr);
LhsPacket plhs = pload<LhsPacket>(lhs_ptr);
#ifdef __NDEBUG__
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
#endif
acc.packet[0] += plhs*prhs;
lhs_ptr += (rows/accLhsProgress)*accLhsProgress;
rhs_ptr++;
}
r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc.packet[0]);
}
// Leftover rows x leftover column: plain scalar dot product.
auto row_residue = 0;
for(;row < rows; row++)
{
const LhsScalar *lhs_ptr = lhsMap.get_residue_at(row_residue);
const RhsScalar *rhs_ptr = rhsMap.get_residue_at(col_residue);
AccScalar acc = 0;
auto k = 0;
for(; k < depth; k++)
{
#ifdef __NDEBUG__
// NOTE(review): plhs and prhs are not declared in this scope, so this trace
// would not compile if __NDEBUG__ were defined.
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
#endif
acc += (*lhs_ptr)*(*rhs_ptr);
lhs_ptr++;
rhs_ptr++;
}
//r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc.packet[0]);
res(row, col) += acc;
row_residue++;
}
col_residue++;
}
}
// Earlier variant of gemm: adds the product of the packed blocks blockA and
// blockB to res, but with the depth loop OUTSIDE the row loop. For every
// (column panel, k) pair it walks all rows and immediately adds the partial
// product for that single k to res, instead of keeping an accumulator across
// the whole depth.
//
// res      destination mapper; every computed value is ADDED to what it holds
// blockA   packed left-hand block, viewed through PackMap as rows x depth
// blockB   packed right-hand block, viewed through PackMap as depth x cols
// strideA/strideB  a value of -1 is replaced by depth
// offsetA/offsetB  only forwarded to PackMap
//
// NOTE(review): pAlpha is built from alpha but never read afterwards, so the
// result is not scaled by alpha here. Confirm whether that is intended.
template<typename ResScalar, typename AccScalar, typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper>
EIGEN_STRONG_INLINE void gemm_old(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB,
Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
using AccPacket = typename packet_traits<AccScalar>::type;
using LhsPacket = typename packet_traits<LhsScalar>::type;
using RhsPacket = typename packet_traits<RhsScalar>::type;
using ResPacket = typename packet_traits<ResScalar>::type;
// Built but not referenced again in this function (see NOTE(review) above).
ResPacket pAlpha = pset1<ResPacket>(alpha);
#ifdef __DEBUG__
// Debug dump of both packed blocks, 4 values per line.
std::cout << "blockA" << std::endl;
for(auto i = 0; i < rows*depth; i++)
{
if(i % 4 == 0 && i > 0)
std::cout << std::endl;
std::cout << blockA[i] << " ";
}
std::cout << std::endl;
std::cout << "blockB" << std::endl;
for(auto i = 0; i < depth*cols; i++)
{
if(i % 4 == 0 && i > 0)
std::cout << std::endl;
std::cout << blockB[i] << " ";
}
std::cout << std::endl;
#endif
// -1 means "no explicit stride": fall back to depth.
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
// Hard-coded tile sizes: 4 rows per LHS packet step, 4 columns per RHS panel.
int accLhsProgress = 4;
int accRhsProgress = 4;
PackMap<LhsScalar, LhsPacket, Index> lhsMap(blockA, rows, depth, offsetA, strideA);
PackMap<RhsScalar, RhsPacket, Index, false> rhsMap(blockB, depth, cols, offsetB, strideB);
auto col = 0;
// ---- Column panels (4 columns at a time) ----
for(; col < rhsMap.get_packed_size(); col+=accRhsProgress)
{
for(auto k = 0; k < depth; k++)
{
const LhsScalar *lhs_ptr = lhsMap.get_packed_at(k);
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress) + k*accRhsProgress;
// acc is overwritten (not accumulated) for every row step below.
PacketBlock<AccPacket, 4> acc;
// Load the 4 RHS values for this k and broadcast each into its own packet.
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
PacketBlock<RhsPacket, 4> pbrhs;
pbrhs.packet[0] = pset1<RhsPacket>(prhs[0]);
pbrhs.packet[1] = pset1<RhsPacket>(prhs[1]);
pbrhs.packet[2] = pset1<RhsPacket>(prhs[2]);
pbrhs.packet[3] = pset1<RhsPacket>(prhs[3]);
auto row = 0;
using LinearMapper = typename DataMapper::LinearMapper;
// Packet rows: 4 rows x 4 columns per step, added straight into res.
for(; row < lhsMap.get_packed_size(); row+=accLhsProgress)
{
LinearMapper r0 = res.getLinearMapper(row, col + 0);
LinearMapper r1 = res.getLinearMapper(row, col + 1);
LinearMapper r2 = res.getLinearMapper(row, col + 2);
LinearMapper r3 = res.getLinearMapper(row, col + 3);
LhsPacket plhs = pload<LhsPacket>(lhs_ptr);
#ifdef __NDEBUG__
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
#endif
acc.packet[0] = plhs*pbrhs.packet[0];
acc.packet[1] = plhs*pbrhs.packet[1];
acc.packet[2] = plhs*pbrhs.packet[2];
acc.packet[3] = plhs*pbrhs.packet[3];
r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc.packet[0]);
r1.storePacket(0,r1.template loadPacket<ResPacket>(0) + acc.packet[1]);
r2.storePacket(0,r2.template loadPacket<ResPacket>(0) + acc.packet[2]);
r3.storePacket(0,r3.template loadPacket<ResPacket>(0) + acc.packet[3]);
lhs_ptr += accLhsProgress;
}
// Leftover rows: scalar LHS value for this k against the 4 RHS values.
auto residue = 0;
for(;row < rows; row++)
{
LhsScalar lhs = *(lhsMap.get_residue_at(residue) + k);
#ifdef __NDEBUG__
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
std::cout << "lhs " << lhs << " (" << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << ")" << std::endl;
#endif
res(row, col + 0) += lhs*prhs[0];
res(row, col + 1) += lhs*prhs[1];
res(row, col + 2) += lhs*prhs[2];
res(row, col + 3) += lhs*prhs[3];
residue++;
}
}
}
// ---- Leftover columns, one at a time ----
auto colResidue = 0;
for(;col < cols; col++)
{
for(auto k = 0; k < depth; k++)
{
const LhsScalar *lhs_ptr = lhsMap.get_packed_at(k);
const RhsScalar *rhs_ptr = rhsMap.get_residue_at(colResidue) + k;
AccPacket acc;
// Single RHS value for this (column, k), broadcast to a packet.
RhsPacket prhs = pset1<RhsPacket>(*rhs_ptr);
auto row = 0;
using LinearMapper = typename DataMapper::LinearMapper;
for(; row < lhsMap.get_packed_size(); row+=accLhsProgress)
{
LinearMapper r0 = res.getLinearMapper(row, col + 0);
LhsPacket plhs = pload<LhsPacket>(lhs_ptr);
#ifdef __DEBUG__
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
#endif
acc = plhs*prhs;
r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc);
lhs_ptr += accLhsProgress;
}
// Leftover rows x leftover column: scalar multiply-add into res.
auto residue = 0;
for(;row < rows; row++)
{
LhsScalar lhs = *(lhsMap.get_residue_at(residue) + k);
#ifdef __DEBUG__
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
std::cout << "lhs " << lhs << " (" << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << ")" << std::endl;
#endif
// All lanes of prhs hold the same value, so lane 0 is the RHS scalar.
res(row, col + 0) += lhs*prhs[0];
residue++;
}
}
colResidue++;
}
}
#endif
template<int Architecture, int CPU, typename LhsScalar, typename RhsScalar>
constexpr int SHAPES_COUNT = 4;
constexpr int SHAPES_DIMENSION = 4;
constexpr int SHAPES_DIMENSION = 6;
constexpr int SHAPES_LHS_DIMENSION = 0;
constexpr int SHAPES_DEP_DIMENSION = 1;
constexpr int SHAPES_RHS_DIMENSION = 2;
constexpr int SHAPES_POINTER = 3;
constexpr int SHAPES_RHS_POINTER = 3;
constexpr int SHAPES_LHS_POINTER = 4;
constexpr int SHAPES_DEP_POINTER = 5;
constexpr int SHAPES_POINTER_END = -1;
template<int Architecture, int CPU, typename Scalar, bool isLhs>
@ -575,18 +36,18 @@ constexpr int PACK_SHAPES_DIMENSION = 3;
constexpr int PACK_SHAPES_POINTER = 2;
constexpr int PACK_SHAPES_END = -1;
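// lhs_progress x depth_progress x rhs_progress (depth_progress > 1 matrix ops) x pointer to next rhs_progress x pointer to next lhs_progress x pointer to next depth_progress on the shapes map (SHAPES_POINTER_END = no next shape)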
template<int Architecture, int CPU, typename LhsScalar, typename RhsScalar>
constexpr int SHAPES[SHAPES_COUNT<Architecture, CPU, LhsScalar,RhsScalar>][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{4,1,1,0},{1,1,4,1},{4,1,4,1}};
constexpr int SHAPES[SHAPES_COUNT<Architecture, CPU, LhsScalar,RhsScalar>][SHAPES_DIMENSION] =
{ {1,1,1,SHAPES_POINTER_END, SHAPES_POINTER_END, SHAPES_POINTER_END},
{4,1,1, 0, 0, SHAPES_POINTER_END},
{1,1,4, 1, SHAPES_POINTER_END, SHAPES_POINTER_END},
{4,1,4, 1, 2, SHAPES_POINTER_END}};
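// d1progress x d2progress x index of the next entry on the pack shapes map (PACK_SHAPES_POINTER column; PACK_SHAPES_END = no next entry)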
template<int Architecture, int CPU, typename Scalar, bool isLhs>
constexpr int PACK_SHAPES[PACK_SHAPES_COUNT<Architecture, CPU, Scalar, isLhs>][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0}};
//template<int Architecture, int CPU, typename Scalar>
//constexpr int PACK_SHAPES<Architecture, CPU, Scalar, false>[PACK_SHAPES_COUNT<Architecture, CPU, Scalar, false>][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0}};
template<int Architecture, int CPU, typename Index, typename Scalar, bool isLhs, typename DataMapper, bool Conjugate, bool PanelMode, int StorageOrder, int M, int N>
struct PackingOperator
{
@ -816,27 +277,29 @@ struct MicroKernel
template<int Architecture, int CPU, typename Index, typename LhsScalar, typename LhsPackMap, typename RhsScalar, typename RhsPackMap, typename AccScalar, typename ResScalar, typename ResPacket, typename DataMapper, int RHS_SHAPE_IDX, int LHS_SHAPE_IDX, int IDX>
struct DepthLoopStruct
{
DepthLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, RHS_SHAPE_IDX, LHS_SHAPE_IDX, IDX-1> depthLS;
static constexpr auto PREVIOUS = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_DEP_POINTER];
DepthLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, RHS_SHAPE_IDX, LHS_SHAPE_IDX, PREVIOUS> depthLS;
EIGEN_STRONG_INLINE void operator()(Index rowIdx, Index colIdx, Index depthIdx, const DataMapper& res,
Index rows, Index depth, Index cols, ResScalar alpha, const ResPacket& pAlpha, LhsPackMap& lhsPackMap, RhsPackMap& rhsPackMap)
{
constexpr auto rhsProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[RHS_SHAPE_IDX][SHAPES_RHS_DIMENSION];
constexpr auto lhsProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[LHS_SHAPE_IDX][SHAPES_LHS_DIMENSION];
constexpr auto depthProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_DEP_DIMENSION];
typedef Accumulator<Architecture, CPU, AccScalar, ResScalar, DataMapper, lhsProgress, rhsProgress> AccumulatorType;
if(rhsProgress == SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_RHS_DIMENSION] && lhsProgress == SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_LHS_DIMENSION])
MicroKernel<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, AccumulatorType, lhsProgress, depthProgress, rhsProgress> mkt;
AccumulatorType acc;
acc.zero();
for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress)
{
MicroKernel<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, AccumulatorType, lhsProgress, depthProgress, rhsProgress> mkt;
AccumulatorType acc;
acc.zero();
for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress)
{
mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc);
}
acc.scale(alpha, pAlpha);
acc.store(res, rowIdx, colIdx);
mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc);
}
acc.scale(alpha, pAlpha);
acc.store(res, rowIdx, colIdx);
depthLS(rowIdx, colIdx, depthIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap);
}
};
@ -851,7 +314,9 @@ struct DepthLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScala
template<int Architecture, int CPU, typename Index, typename LhsScalar, typename LhsPackMap, typename RhsScalar, typename RhsPackMap, typename AccScalar, typename ResScalar, typename ResPacket, typename DataMapper, int RHS_SHAPE_IDX, int IDX>
struct LhsLoopStruct
{
LhsLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, RHS_SHAPE_IDX, IDX-1> lhsLS;
static constexpr auto PREVIOUS = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_LHS_POINTER];
LhsLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, RHS_SHAPE_IDX, PREVIOUS> lhsLS;
EIGEN_STRONG_INLINE void operator()(Index rowIdx, int colIdx, const DataMapper& res,
Index rows, Index depth, Index cols, ResScalar alpha, const ResPacket& pAlpha, LhsPackMap& lhsPackMap, RhsPackMap& rhsPackMap)
{
@ -878,7 +343,7 @@ struct LhsLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar,
template<int Architecture, int CPU, typename Index, typename LhsScalar, typename LhsPackMap, typename RhsScalar, typename RhsPackMap, typename AccScalar, typename ResScalar, typename ResPacket, typename DataMapper, int IDX>
struct RhsLoopStruct
{
static constexpr auto PREVIOUS = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_POINTER];
static constexpr auto PREVIOUS = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_RHS_POINTER];
RhsLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, PREVIOUS> rhsLS;
EIGEN_STRONG_INLINE void operator()(Index colIdx, const DataMapper& res,