mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-10-16 18:11:29 +08:00
WIP
This commit is contained in:
parent
c62ed9b214
commit
646d92c7f1
@ -17,9 +17,16 @@ namespace internal {
|
||||
template<int CPU, typename LhsScalar, typename RhsScalar>
|
||||
constexpr int SHAPES_COUNT<0, CPU, LhsScalar, RhsScalar> = 7;
|
||||
|
||||
// lhs_progress x depth_progress x rhs_progress (depth_progress > 1 matrix ops) x pointer to next rhs_progress on the shapes map
|
||||
template<int CPU, typename LhsScalar, typename RhsScalar>
|
||||
constexpr int SHAPES<0, CPU, LhsScalar, RhsScalar>[SHAPES_COUNT<0, CPU, LhsScalar,RhsScalar>][SHAPES_DIMENSION] =
|
||||
{{1,1,1,SHAPES_POINTER_END},{4,1,1,0},{1,1,4,1},{4,1,4,1},{4,4,4,1},{8,1,4,1},{8,4,4,1}};
|
||||
{ {1,1,1,SHAPES_POINTER_END, SHAPES_POINTER_END, SHAPES_POINTER_END},
|
||||
{4,1,1, 0, 0, SHAPES_POINTER_END},
|
||||
{1,1,4, 1, SHAPES_POINTER_END, SHAPES_POINTER_END},
|
||||
{4,1,4, 1, 2, SHAPES_POINTER_END},
|
||||
{4,4,4, 1, 2, 3},
|
||||
{8,1,4, 1, 4, SHAPES_POINTER_END},
|
||||
{8,4,4, 1, 4, 5}};
|
||||
|
||||
template<int CPU, typename Scalar, typename ResScalar, typename DataMapper>
|
||||
struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 1>
|
||||
|
@ -18,555 +18,16 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
|
||||
#ifdef __OLD__
|
||||
template<typename Scalar, typename Packet, typename Index, bool IsLhs = true>
|
||||
class PackMap
|
||||
{
|
||||
const int packetSize = packet_traits<Scalar>::size;
|
||||
const Scalar *packed_block;
|
||||
const Scalar *residue_block;
|
||||
Index packed_stride;
|
||||
Index residue_size;
|
||||
Index rows, cols;
|
||||
Index offset, stride;
|
||||
Scalar *cur;
|
||||
public:
|
||||
PackMap(const Scalar *packed_block, const Scalar *residue_block, Index rows, Index cols, Index offset, Index stride) : packed_block(packed_block), residue_block(residue_block), rows(rows), cols(cols), offset(offset), stride(stride)
|
||||
{
|
||||
if(IsLhs)
|
||||
{
|
||||
packed_stride = (rows / packetSize) * packetSize;
|
||||
residue_size = rows % packetSize;
|
||||
}
|
||||
else {
|
||||
packed_stride = (cols / packetSize) * packetSize;
|
||||
residue_size = cols % packetSize;
|
||||
}
|
||||
};
|
||||
|
||||
PackMap(const Scalar *packed_block, Index rows, Index cols, Index offset, Index stride) : packed_block(packed_block), rows(rows), cols(cols)
|
||||
{
|
||||
if(IsLhs)
|
||||
{
|
||||
packed_stride = (rows / packetSize) * packetSize;
|
||||
residue_block = packed_block + packed_stride*cols;
|
||||
residue_size = rows % packetSize;
|
||||
}
|
||||
else {
|
||||
packed_stride = (cols / packetSize) * packetSize;
|
||||
residue_block = packed_block + packed_stride*rows;
|
||||
residue_size = cols % packetSize;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
EIGEN_STRONG_INLINE Index get_packed_size()
|
||||
{
|
||||
return packed_stride;
|
||||
};
|
||||
|
||||
EIGEN_STRONG_INLINE Index get_residue_size()
|
||||
{
|
||||
return residue_size;
|
||||
};
|
||||
|
||||
EIGEN_STRONG_INLINE const Scalar* get_packed_at(Index at)
|
||||
{
|
||||
return IsLhs ? packed_block + at : packed_block + at*packetSize*rows;
|
||||
};
|
||||
|
||||
EIGEN_STRONG_INLINE const Scalar* get_residue_at(Index at)
|
||||
{
|
||||
return residue_block + stride*at;
|
||||
};
|
||||
};
|
||||
|
||||
template<typename ResScalar, typename AccScalar, typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper>
|
||||
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB,
|
||||
Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||
{
|
||||
using AccPacket = typename packet_traits<AccScalar>::type;
|
||||
using LhsPacket = typename packet_traits<LhsScalar>::type;
|
||||
using RhsPacket = typename packet_traits<RhsScalar>::type;
|
||||
using ResPacket = typename packet_traits<ResScalar>::type;
|
||||
using LinearMapper = typename DataMapper::LinearMapper;
|
||||
|
||||
if( strideA == -1 ) strideA = depth;
|
||||
if( strideB == -1 ) strideB = depth;
|
||||
|
||||
ResPacket pAlpha = pset1<ResPacket>(alpha);
|
||||
|
||||
#ifdef __DEBUG__
|
||||
std::cout << "blockA" << std::endl;
|
||||
for(auto i = 0; i < rows*depth; i++)
|
||||
{
|
||||
if(i % strideA == 0 && i > 0)
|
||||
std::cout << std::endl;
|
||||
std::cout << blockA[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
std::cout << "blockB" << std::endl;
|
||||
for(auto i = 0; i < depth*cols; i++)
|
||||
{
|
||||
if(i % strideB == 0 && i > 0)
|
||||
std::cout << std::endl;
|
||||
std::cout << blockB[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
#endif
|
||||
|
||||
int accLhsProgress = 4;
|
||||
int accRhsProgress = 4;
|
||||
|
||||
PackMap<LhsScalar, LhsPacket, Index> lhsMap(blockA, rows, depth, offsetA, strideA);
|
||||
PackMap<RhsScalar, RhsPacket, Index, false> rhsMap(blockB, depth, cols, offsetB, strideB);
|
||||
auto col = 0;
|
||||
for(; col + accRhsProgress <= rhsMap.get_packed_size(); col+=accRhsProgress)
|
||||
{
|
||||
auto row = 0;
|
||||
for(; row + 3*accLhsProgress <= lhsMap.get_packed_size(); row+=3*accLhsProgress)
|
||||
{
|
||||
const LhsScalar *lhs_ptr1 = lhsMap.get_packed_at(row + 0*accLhsProgress);
|
||||
const LhsScalar *lhs_ptr2 = lhsMap.get_packed_at(row + 1*accLhsProgress);
|
||||
const LhsScalar *lhs_ptr3 = lhsMap.get_packed_at(row + 2*accLhsProgress);
|
||||
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress);
|
||||
|
||||
PacketBlock<AccPacket, 4> acc1;
|
||||
acc1.packet[0] = pset1<AccPacket>(0);
|
||||
acc1.packet[1] = pset1<AccPacket>(0);
|
||||
acc1.packet[2] = pset1<AccPacket>(0);
|
||||
acc1.packet[3] = pset1<AccPacket>(0);
|
||||
|
||||
PacketBlock<AccPacket, 4> acc2;
|
||||
acc2.packet[0] = pset1<AccPacket>(0);
|
||||
acc2.packet[1] = pset1<AccPacket>(0);
|
||||
acc2.packet[2] = pset1<AccPacket>(0);
|
||||
acc2.packet[3] = pset1<AccPacket>(0);
|
||||
|
||||
PacketBlock<AccPacket, 4> acc3;
|
||||
acc3.packet[0] = pset1<AccPacket>(0);
|
||||
acc3.packet[1] = pset1<AccPacket>(0);
|
||||
acc3.packet[2] = pset1<AccPacket>(0);
|
||||
acc3.packet[3] = pset1<AccPacket>(0);
|
||||
|
||||
LinearMapper r00 = res.getLinearMapper(row + 0*accLhsProgress, col + 0);
|
||||
LinearMapper r01 = res.getLinearMapper(row + 0*accLhsProgress, col + 1);
|
||||
LinearMapper r02 = res.getLinearMapper(row + 0*accLhsProgress, col + 2);
|
||||
LinearMapper r03 = res.getLinearMapper(row + 0*accLhsProgress, col + 3);
|
||||
|
||||
LinearMapper r10 = res.getLinearMapper(row + 1*accLhsProgress, col + 0);
|
||||
LinearMapper r11 = res.getLinearMapper(row + 1*accLhsProgress, col + 1);
|
||||
LinearMapper r12 = res.getLinearMapper(row + 1*accLhsProgress, col + 2);
|
||||
LinearMapper r13 = res.getLinearMapper(row + 1*accLhsProgress, col + 3);
|
||||
|
||||
LinearMapper r20 = res.getLinearMapper(row + 2*accLhsProgress, col + 0);
|
||||
LinearMapper r21 = res.getLinearMapper(row + 2*accLhsProgress, col + 1);
|
||||
LinearMapper r22 = res.getLinearMapper(row + 2*accLhsProgress, col + 2);
|
||||
LinearMapper r23 = res.getLinearMapper(row + 2*accLhsProgress, col + 3);
|
||||
|
||||
auto k = 0;
|
||||
for(; k < depth; k++)
|
||||
{
|
||||
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
|
||||
PacketBlock<RhsPacket, 4> pbrhs;
|
||||
pbrhs.packet[0] = pset1<RhsPacket>(prhs[0]);
|
||||
pbrhs.packet[1] = pset1<RhsPacket>(prhs[1]);
|
||||
pbrhs.packet[2] = pset1<RhsPacket>(prhs[2]);
|
||||
pbrhs.packet[3] = pset1<RhsPacket>(prhs[3]);
|
||||
|
||||
LhsPacket plhs1 = pload<LhsPacket>(lhs_ptr1);
|
||||
LhsPacket plhs2 = pload<LhsPacket>(lhs_ptr2);
|
||||
LhsPacket plhs3 = pload<LhsPacket>(lhs_ptr3);
|
||||
|
||||
acc1.packet[0] += plhs1*pbrhs.packet[0];
|
||||
acc1.packet[1] += plhs1*pbrhs.packet[1];
|
||||
acc1.packet[2] += plhs1*pbrhs.packet[2];
|
||||
acc1.packet[3] += plhs1*pbrhs.packet[3];
|
||||
|
||||
acc2.packet[0] += plhs2*pbrhs.packet[0];
|
||||
acc2.packet[1] += plhs2*pbrhs.packet[1];
|
||||
acc2.packet[2] += plhs2*pbrhs.packet[2];
|
||||
acc2.packet[3] += plhs2*pbrhs.packet[3];
|
||||
|
||||
acc3.packet[0] += plhs3*pbrhs.packet[0];
|
||||
acc3.packet[1] += plhs3*pbrhs.packet[1];
|
||||
acc3.packet[2] += plhs3*pbrhs.packet[2];
|
||||
acc3.packet[3] += plhs3*pbrhs.packet[3];
|
||||
|
||||
lhs_ptr1 += (rows/accLhsProgress)*accLhsProgress;
|
||||
lhs_ptr2 += (rows/accLhsProgress)*accLhsProgress;
|
||||
lhs_ptr3 += (rows/accLhsProgress)*accLhsProgress;
|
||||
rhs_ptr += accRhsProgress;
|
||||
}
|
||||
|
||||
r00.storePacket(0,r00.template loadPacket<ResPacket>(0) + acc1.packet[0]);
|
||||
r01.storePacket(0,r01.template loadPacket<ResPacket>(0) + acc1.packet[1]);
|
||||
r02.storePacket(0,r02.template loadPacket<ResPacket>(0) + acc1.packet[2]);
|
||||
r03.storePacket(0,r03.template loadPacket<ResPacket>(0) + acc1.packet[3]);
|
||||
|
||||
r10.storePacket(0,r10.template loadPacket<ResPacket>(0) + acc2.packet[0]);
|
||||
r11.storePacket(0,r11.template loadPacket<ResPacket>(0) + acc2.packet[1]);
|
||||
r12.storePacket(0,r12.template loadPacket<ResPacket>(0) + acc2.packet[2]);
|
||||
r13.storePacket(0,r13.template loadPacket<ResPacket>(0) + acc2.packet[3]);
|
||||
|
||||
r20.storePacket(0,r20.template loadPacket<ResPacket>(0) + acc3.packet[0]);
|
||||
r21.storePacket(0,r21.template loadPacket<ResPacket>(0) + acc3.packet[1]);
|
||||
r22.storePacket(0,r22.template loadPacket<ResPacket>(0) + acc3.packet[2]);
|
||||
r23.storePacket(0,r23.template loadPacket<ResPacket>(0) + acc3.packet[3]);
|
||||
}
|
||||
for(; row + 2*accLhsProgress <= lhsMap.get_packed_size(); row+=2*accLhsProgress)
|
||||
{
|
||||
const LhsScalar *lhs_ptr1 = lhsMap.get_packed_at(row + 0*accLhsProgress);
|
||||
const LhsScalar *lhs_ptr2 = lhsMap.get_packed_at(row + 1*accLhsProgress);
|
||||
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress);
|
||||
|
||||
PacketBlock<AccPacket, 4> acc1;
|
||||
acc1.packet[0] = pset1<AccPacket>(0);
|
||||
acc1.packet[1] = pset1<AccPacket>(0);
|
||||
acc1.packet[2] = pset1<AccPacket>(0);
|
||||
acc1.packet[3] = pset1<AccPacket>(0);
|
||||
|
||||
PacketBlock<AccPacket, 4> acc2;
|
||||
acc2.packet[0] = pset1<AccPacket>(0);
|
||||
acc2.packet[1] = pset1<AccPacket>(0);
|
||||
acc2.packet[2] = pset1<AccPacket>(0);
|
||||
acc2.packet[3] = pset1<AccPacket>(0);
|
||||
|
||||
LinearMapper r00 = res.getLinearMapper(row + 0*accLhsProgress, col + 0);
|
||||
LinearMapper r01 = res.getLinearMapper(row + 0*accLhsProgress, col + 1);
|
||||
LinearMapper r02 = res.getLinearMapper(row + 0*accLhsProgress, col + 2);
|
||||
LinearMapper r03 = res.getLinearMapper(row + 0*accLhsProgress, col + 3);
|
||||
|
||||
LinearMapper r10 = res.getLinearMapper(row + 1*accLhsProgress, col + 0);
|
||||
LinearMapper r11 = res.getLinearMapper(row + 1*accLhsProgress, col + 1);
|
||||
LinearMapper r12 = res.getLinearMapper(row + 1*accLhsProgress, col + 2);
|
||||
LinearMapper r13 = res.getLinearMapper(row + 1*accLhsProgress, col + 3);
|
||||
|
||||
auto k = 0;
|
||||
for(; k < depth; k++)
|
||||
{
|
||||
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
|
||||
PacketBlock<RhsPacket, 4> pbrhs;
|
||||
pbrhs.packet[0] = pset1<RhsPacket>(prhs[0]);
|
||||
pbrhs.packet[1] = pset1<RhsPacket>(prhs[1]);
|
||||
pbrhs.packet[2] = pset1<RhsPacket>(prhs[2]);
|
||||
pbrhs.packet[3] = pset1<RhsPacket>(prhs[3]);
|
||||
|
||||
LhsPacket plhs1 = pload<LhsPacket>(lhs_ptr1);
|
||||
LhsPacket plhs2 = pload<LhsPacket>(lhs_ptr2);
|
||||
|
||||
acc1.packet[0] += plhs1*pbrhs.packet[0];
|
||||
acc1.packet[1] += plhs1*pbrhs.packet[1];
|
||||
acc1.packet[2] += plhs1*pbrhs.packet[2];
|
||||
acc1.packet[3] += plhs1*pbrhs.packet[3];
|
||||
|
||||
acc2.packet[0] += plhs2*pbrhs.packet[0];
|
||||
acc2.packet[1] += plhs2*pbrhs.packet[1];
|
||||
acc2.packet[2] += plhs2*pbrhs.packet[2];
|
||||
acc2.packet[3] += plhs2*pbrhs.packet[3];
|
||||
|
||||
lhs_ptr1 += (rows/accLhsProgress)*accLhsProgress;
|
||||
lhs_ptr2 += (rows/accLhsProgress)*accLhsProgress;
|
||||
rhs_ptr += accRhsProgress;
|
||||
}
|
||||
|
||||
r00.storePacket(0,r00.template loadPacket<ResPacket>(0) + acc1.packet[0]);
|
||||
r01.storePacket(0,r01.template loadPacket<ResPacket>(0) + acc1.packet[1]);
|
||||
r02.storePacket(0,r02.template loadPacket<ResPacket>(0) + acc1.packet[2]);
|
||||
r03.storePacket(0,r03.template loadPacket<ResPacket>(0) + acc1.packet[3]);
|
||||
|
||||
r10.storePacket(0,r10.template loadPacket<ResPacket>(0) + acc2.packet[0]);
|
||||
r11.storePacket(0,r11.template loadPacket<ResPacket>(0) + acc2.packet[1]);
|
||||
r12.storePacket(0,r12.template loadPacket<ResPacket>(0) + acc2.packet[2]);
|
||||
r13.storePacket(0,r13.template loadPacket<ResPacket>(0) + acc2.packet[3]);
|
||||
}
|
||||
for(; row + accLhsProgress <= lhsMap.get_packed_size(); row+=accLhsProgress)
|
||||
{
|
||||
const LhsScalar *lhs_ptr = lhsMap.get_packed_at(row);
|
||||
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress);
|
||||
PacketBlock<AccPacket, 4> acc;
|
||||
acc.packet[0] = pset1<AccPacket>(0);
|
||||
acc.packet[1] = pset1<AccPacket>(0);
|
||||
acc.packet[2] = pset1<AccPacket>(0);
|
||||
acc.packet[3] = pset1<AccPacket>(0);
|
||||
|
||||
LinearMapper r0 = res.getLinearMapper(row, col + 0);
|
||||
LinearMapper r1 = res.getLinearMapper(row, col + 1);
|
||||
LinearMapper r2 = res.getLinearMapper(row, col + 2);
|
||||
LinearMapper r3 = res.getLinearMapper(row, col + 3);
|
||||
|
||||
auto k = 0;
|
||||
for(; k < depth; k++)
|
||||
{
|
||||
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
|
||||
PacketBlock<RhsPacket, 4> pbrhs;
|
||||
pbrhs.packet[0] = pset1<RhsPacket>(prhs[0]);
|
||||
pbrhs.packet[1] = pset1<RhsPacket>(prhs[1]);
|
||||
pbrhs.packet[2] = pset1<RhsPacket>(prhs[2]);
|
||||
pbrhs.packet[3] = pset1<RhsPacket>(prhs[3]);
|
||||
|
||||
LhsPacket plhs = pload<LhsPacket>(lhs_ptr);
|
||||
|
||||
#ifdef __NDEBUG__
|
||||
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
|
||||
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
|
||||
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
|
||||
#endif
|
||||
acc.packet[0] += plhs*pbrhs.packet[0];
|
||||
acc.packet[1] += plhs*pbrhs.packet[1];
|
||||
acc.packet[2] += plhs*pbrhs.packet[2];
|
||||
acc.packet[3] += plhs*pbrhs.packet[3];
|
||||
|
||||
lhs_ptr += (rows/accLhsProgress)*accLhsProgress;
|
||||
rhs_ptr += accRhsProgress;
|
||||
}
|
||||
|
||||
r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc.packet[0]);
|
||||
r1.storePacket(0,r1.template loadPacket<ResPacket>(0) + acc.packet[1]);
|
||||
r2.storePacket(0,r2.template loadPacket<ResPacket>(0) + acc.packet[2]);
|
||||
r3.storePacket(0,r3.template loadPacket<ResPacket>(0) + acc.packet[3]);
|
||||
}
|
||||
auto row_residue = 0;
|
||||
for(;row < rows; row++)
|
||||
{
|
||||
const LhsScalar *lhs_ptr = lhsMap.get_residue_at(row_residue);
|
||||
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress);
|
||||
PacketBlock<AccPacket, 1> acc;
|
||||
acc.packet[0] = pset1<AccPacket>(0);
|
||||
|
||||
auto k = 0;
|
||||
for(; k < depth; k++)
|
||||
{
|
||||
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
|
||||
LhsPacket plhs = pset1<LhsPacket>(*lhs_ptr);
|
||||
|
||||
#ifdef __NDEBUG__
|
||||
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
|
||||
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
|
||||
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
|
||||
#endif
|
||||
acc.packet[0] += (*lhs_ptr)*prhs;
|
||||
|
||||
lhs_ptr++;
|
||||
rhs_ptr += accRhsProgress;
|
||||
}
|
||||
|
||||
res(row, col + 0) += acc.packet[0][0];
|
||||
res(row, col + 1) += acc.packet[0][1];
|
||||
res(row, col + 2) += acc.packet[0][2];
|
||||
res(row, col + 3) += acc.packet[0][3];
|
||||
row_residue++;
|
||||
}
|
||||
}
|
||||
auto col_residue = 0;
|
||||
for(; col < cols; col++)
|
||||
{
|
||||
auto row = 0;
|
||||
for(; row + accLhsProgress <= lhsMap.get_packed_size(); row+=accLhsProgress)
|
||||
{
|
||||
const LhsScalar *lhs_ptr = lhsMap.get_packed_at(row);
|
||||
const RhsScalar *rhs_ptr = rhsMap.get_residue_at(col_residue);
|
||||
PacketBlock<AccPacket, 1> acc;
|
||||
acc.packet[0] = pset1<AccPacket>(0);
|
||||
|
||||
LinearMapper r0 = res.getLinearMapper(row, col + 0);
|
||||
|
||||
auto k = 0;
|
||||
for(; k < depth; k++)
|
||||
{
|
||||
RhsPacket prhs = pset1<RhsPacket>(*rhs_ptr);
|
||||
|
||||
LhsPacket plhs = pload<LhsPacket>(lhs_ptr);
|
||||
|
||||
#ifdef __NDEBUG__
|
||||
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
|
||||
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
|
||||
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
|
||||
#endif
|
||||
acc.packet[0] += plhs*prhs;
|
||||
|
||||
lhs_ptr += (rows/accLhsProgress)*accLhsProgress;
|
||||
rhs_ptr++;
|
||||
}
|
||||
|
||||
r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc.packet[0]);
|
||||
}
|
||||
auto row_residue = 0;
|
||||
for(;row < rows; row++)
|
||||
{
|
||||
const LhsScalar *lhs_ptr = lhsMap.get_residue_at(row_residue);
|
||||
const RhsScalar *rhs_ptr = rhsMap.get_residue_at(col_residue);
|
||||
AccScalar acc = 0;
|
||||
auto k = 0;
|
||||
for(; k < depth; k++)
|
||||
{
|
||||
#ifdef __NDEBUG__
|
||||
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
|
||||
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
|
||||
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
|
||||
#endif
|
||||
acc += (*lhs_ptr)*(*rhs_ptr);
|
||||
|
||||
lhs_ptr++;
|
||||
rhs_ptr++;
|
||||
}
|
||||
|
||||
//r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc.packet[0]);
|
||||
res(row, col) += acc;
|
||||
row_residue++;
|
||||
}
|
||||
col_residue++;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ResScalar, typename AccScalar, typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper>
|
||||
EIGEN_STRONG_INLINE void gemm_old(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB,
|
||||
Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||
{
|
||||
using AccPacket = typename packet_traits<AccScalar>::type;
|
||||
using LhsPacket = typename packet_traits<LhsScalar>::type;
|
||||
using RhsPacket = typename packet_traits<RhsScalar>::type;
|
||||
using ResPacket = typename packet_traits<ResScalar>::type;
|
||||
|
||||
ResPacket pAlpha = pset1<ResPacket>(alpha);
|
||||
|
||||
#ifdef __DEBUG__
|
||||
std::cout << "blockA" << std::endl;
|
||||
for(auto i = 0; i < rows*depth; i++)
|
||||
{
|
||||
if(i % 4 == 0 && i > 0)
|
||||
std::cout << std::endl;
|
||||
std::cout << blockA[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
std::cout << "blockB" << std::endl;
|
||||
for(auto i = 0; i < depth*cols; i++)
|
||||
{
|
||||
if(i % 4 == 0 && i > 0)
|
||||
std::cout << std::endl;
|
||||
std::cout << blockB[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
#endif
|
||||
|
||||
if( strideA == -1 ) strideA = depth;
|
||||
if( strideB == -1 ) strideB = depth;
|
||||
|
||||
int accLhsProgress = 4;
|
||||
int accRhsProgress = 4;
|
||||
|
||||
PackMap<LhsScalar, LhsPacket, Index> lhsMap(blockA, rows, depth, offsetA, strideA);
|
||||
PackMap<RhsScalar, RhsPacket, Index, false> rhsMap(blockB, depth, cols, offsetB, strideB);
|
||||
auto col = 0;
|
||||
for(; col < rhsMap.get_packed_size(); col+=accRhsProgress)
|
||||
{
|
||||
for(auto k = 0; k < depth; k++)
|
||||
{
|
||||
const LhsScalar *lhs_ptr = lhsMap.get_packed_at(k);
|
||||
const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress) + k*accRhsProgress;
|
||||
PacketBlock<AccPacket, 4> acc;
|
||||
RhsPacket prhs = pload<RhsPacket>(rhs_ptr);
|
||||
PacketBlock<RhsPacket, 4> pbrhs;
|
||||
pbrhs.packet[0] = pset1<RhsPacket>(prhs[0]);
|
||||
pbrhs.packet[1] = pset1<RhsPacket>(prhs[1]);
|
||||
pbrhs.packet[2] = pset1<RhsPacket>(prhs[2]);
|
||||
pbrhs.packet[3] = pset1<RhsPacket>(prhs[3]);
|
||||
auto row = 0;
|
||||
using LinearMapper = typename DataMapper::LinearMapper;
|
||||
for(; row < lhsMap.get_packed_size(); row+=accLhsProgress)
|
||||
{
|
||||
LinearMapper r0 = res.getLinearMapper(row, col + 0);
|
||||
LinearMapper r1 = res.getLinearMapper(row, col + 1);
|
||||
LinearMapper r2 = res.getLinearMapper(row, col + 2);
|
||||
LinearMapper r3 = res.getLinearMapper(row, col + 3);
|
||||
|
||||
LhsPacket plhs = pload<LhsPacket>(lhs_ptr);
|
||||
#ifdef __NDEBUG__
|
||||
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
|
||||
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
|
||||
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
|
||||
#endif
|
||||
acc.packet[0] = plhs*pbrhs.packet[0];
|
||||
acc.packet[1] = plhs*pbrhs.packet[1];
|
||||
acc.packet[2] = plhs*pbrhs.packet[2];
|
||||
acc.packet[3] = plhs*pbrhs.packet[3];
|
||||
|
||||
r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc.packet[0]);
|
||||
r1.storePacket(0,r1.template loadPacket<ResPacket>(0) + acc.packet[1]);
|
||||
r2.storePacket(0,r2.template loadPacket<ResPacket>(0) + acc.packet[2]);
|
||||
r3.storePacket(0,r3.template loadPacket<ResPacket>(0) + acc.packet[3]);
|
||||
lhs_ptr += accLhsProgress;
|
||||
}
|
||||
auto residue = 0;
|
||||
for(;row < rows; row++)
|
||||
{
|
||||
LhsScalar lhs = *(lhsMap.get_residue_at(residue) + k);
|
||||
#ifdef __NDEBUG__
|
||||
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
|
||||
std::cout << "lhs " << lhs << " (" << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << ")" << std::endl;
|
||||
#endif
|
||||
res(row, col + 0) += lhs*prhs[0];
|
||||
res(row, col + 1) += lhs*prhs[1];
|
||||
res(row, col + 2) += lhs*prhs[2];
|
||||
res(row, col + 3) += lhs*prhs[3];
|
||||
residue++;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto colResidue = 0;
|
||||
for(;col < cols; col++)
|
||||
{
|
||||
for(auto k = 0; k < depth; k++)
|
||||
{
|
||||
const LhsScalar *lhs_ptr = lhsMap.get_packed_at(k);
|
||||
const RhsScalar *rhs_ptr = rhsMap.get_residue_at(colResidue) + k;
|
||||
AccPacket acc;
|
||||
|
||||
RhsPacket prhs = pset1<RhsPacket>(*rhs_ptr);
|
||||
|
||||
auto row = 0;
|
||||
using LinearMapper = typename DataMapper::LinearMapper;
|
||||
for(; row < lhsMap.get_packed_size(); row+=accLhsProgress)
|
||||
{
|
||||
LinearMapper r0 = res.getLinearMapper(row, col + 0);
|
||||
|
||||
LhsPacket plhs = pload<LhsPacket>(lhs_ptr);
|
||||
#ifdef __DEBUG__
|
||||
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
|
||||
std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl;
|
||||
std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl;
|
||||
#endif
|
||||
acc = plhs*prhs;
|
||||
|
||||
r0.storePacket(0,r0.template loadPacket<ResPacket>(0) + acc);
|
||||
lhs_ptr += accLhsProgress;
|
||||
}
|
||||
auto residue = 0;
|
||||
for(;row < rows; row++)
|
||||
{
|
||||
LhsScalar lhs = *(lhsMap.get_residue_at(residue) + k);
|
||||
#ifdef __DEBUG__
|
||||
std::cout << "(" << row << "," << k << "," << col << ")" << std::endl;
|
||||
std::cout << "lhs " << lhs << " (" << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << ")" << std::endl;
|
||||
#endif
|
||||
res(row, col + 0) += lhs*prhs[0];
|
||||
residue++;
|
||||
}
|
||||
}
|
||||
colResidue++;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template<int Architecture, int CPU, typename LhsScalar, typename RhsScalar>
|
||||
constexpr int SHAPES_COUNT = 4;
|
||||
|
||||
constexpr int SHAPES_DIMENSION = 4;
|
||||
constexpr int SHAPES_DIMENSION = 6;
|
||||
constexpr int SHAPES_LHS_DIMENSION = 0;
|
||||
constexpr int SHAPES_DEP_DIMENSION = 1;
|
||||
constexpr int SHAPES_RHS_DIMENSION = 2;
|
||||
constexpr int SHAPES_POINTER = 3;
|
||||
constexpr int SHAPES_RHS_POINTER = 3;
|
||||
constexpr int SHAPES_LHS_POINTER = 4;
|
||||
constexpr int SHAPES_DEP_POINTER = 5;
|
||||
constexpr int SHAPES_POINTER_END = -1;
|
||||
|
||||
template<int Architecture, int CPU, typename Scalar, bool isLhs>
|
||||
@ -575,18 +36,18 @@ constexpr int PACK_SHAPES_DIMENSION = 3;
|
||||
constexpr int PACK_SHAPES_POINTER = 2;
|
||||
constexpr int PACK_SHAPES_END = -1;
|
||||
|
||||
|
||||
// lhs_progress x depth_progress x rhs_progress (depth_progress > 1 matrix ops) x pointer to next rhs_progress on the shapes map
|
||||
template<int Architecture, int CPU, typename LhsScalar, typename RhsScalar>
|
||||
constexpr int SHAPES[SHAPES_COUNT<Architecture, CPU, LhsScalar,RhsScalar>][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{4,1,1,0},{1,1,4,1},{4,1,4,1}};
|
||||
constexpr int SHAPES[SHAPES_COUNT<Architecture, CPU, LhsScalar,RhsScalar>][SHAPES_DIMENSION] =
|
||||
{ {1,1,1,SHAPES_POINTER_END, SHAPES_POINTER_END, SHAPES_POINTER_END},
|
||||
{4,1,1, 0, 0, SHAPES_POINTER_END},
|
||||
{1,1,4, 1, SHAPES_POINTER_END, SHAPES_POINTER_END},
|
||||
{4,1,4, 1, 2, SHAPES_POINTER_END}};
|
||||
|
||||
// d1progress x d2progress
|
||||
template<int Architecture, int CPU, typename Scalar, bool isLhs>
|
||||
constexpr int PACK_SHAPES[PACK_SHAPES_COUNT<Architecture, CPU, Scalar, isLhs>][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0}};
|
||||
|
||||
//template<int Architecture, int CPU, typename Scalar>
|
||||
//constexpr int PACK_SHAPES<Architecture, CPU, Scalar, false>[PACK_SHAPES_COUNT<Architecture, CPU, Scalar, false>][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0}};
|
||||
|
||||
template<int Architecture, int CPU, typename Index, typename Scalar, bool isLhs, typename DataMapper, bool Conjugate, bool PanelMode, int StorageOrder, int M, int N>
|
||||
struct PackingOperator
|
||||
{
|
||||
@ -816,27 +277,29 @@ struct MicroKernel
|
||||
template<int Architecture, int CPU, typename Index, typename LhsScalar, typename LhsPackMap, typename RhsScalar, typename RhsPackMap, typename AccScalar, typename ResScalar, typename ResPacket, typename DataMapper, int RHS_SHAPE_IDX, int LHS_SHAPE_IDX, int IDX>
|
||||
struct DepthLoopStruct
|
||||
{
|
||||
DepthLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, RHS_SHAPE_IDX, LHS_SHAPE_IDX, IDX-1> depthLS;
|
||||
static constexpr auto PREVIOUS = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_DEP_POINTER];
|
||||
|
||||
DepthLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, RHS_SHAPE_IDX, LHS_SHAPE_IDX, PREVIOUS> depthLS;
|
||||
|
||||
EIGEN_STRONG_INLINE void operator()(Index rowIdx, Index colIdx, Index depthIdx, const DataMapper& res,
|
||||
Index rows, Index depth, Index cols, ResScalar alpha, const ResPacket& pAlpha, LhsPackMap& lhsPackMap, RhsPackMap& rhsPackMap)
|
||||
{
|
||||
constexpr auto rhsProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[RHS_SHAPE_IDX][SHAPES_RHS_DIMENSION];
|
||||
constexpr auto lhsProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[LHS_SHAPE_IDX][SHAPES_LHS_DIMENSION];
|
||||
constexpr auto depthProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_DEP_DIMENSION];
|
||||
|
||||
typedef Accumulator<Architecture, CPU, AccScalar, ResScalar, DataMapper, lhsProgress, rhsProgress> AccumulatorType;
|
||||
|
||||
if(rhsProgress == SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_RHS_DIMENSION] && lhsProgress == SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_LHS_DIMENSION])
|
||||
MicroKernel<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, AccumulatorType, lhsProgress, depthProgress, rhsProgress> mkt;
|
||||
AccumulatorType acc;
|
||||
acc.zero();
|
||||
for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress)
|
||||
{
|
||||
MicroKernel<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, AccumulatorType, lhsProgress, depthProgress, rhsProgress> mkt;
|
||||
AccumulatorType acc;
|
||||
acc.zero();
|
||||
for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress)
|
||||
{
|
||||
mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc);
|
||||
}
|
||||
acc.scale(alpha, pAlpha);
|
||||
acc.store(res, rowIdx, colIdx);
|
||||
mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc);
|
||||
}
|
||||
acc.scale(alpha, pAlpha);
|
||||
acc.store(res, rowIdx, colIdx);
|
||||
|
||||
depthLS(rowIdx, colIdx, depthIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap);
|
||||
}
|
||||
};
|
||||
@ -851,7 +314,9 @@ struct DepthLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScala
|
||||
template<int Architecture, int CPU, typename Index, typename LhsScalar, typename LhsPackMap, typename RhsScalar, typename RhsPackMap, typename AccScalar, typename ResScalar, typename ResPacket, typename DataMapper, int RHS_SHAPE_IDX, int IDX>
|
||||
struct LhsLoopStruct
|
||||
{
|
||||
LhsLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, RHS_SHAPE_IDX, IDX-1> lhsLS;
|
||||
static constexpr auto PREVIOUS = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_LHS_POINTER];
|
||||
LhsLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, RHS_SHAPE_IDX, PREVIOUS> lhsLS;
|
||||
|
||||
EIGEN_STRONG_INLINE void operator()(Index rowIdx, int colIdx, const DataMapper& res,
|
||||
Index rows, Index depth, Index cols, ResScalar alpha, const ResPacket& pAlpha, LhsPackMap& lhsPackMap, RhsPackMap& rhsPackMap)
|
||||
{
|
||||
@ -878,7 +343,7 @@ struct LhsLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar,
|
||||
template<int Architecture, int CPU, typename Index, typename LhsScalar, typename LhsPackMap, typename RhsScalar, typename RhsPackMap, typename AccScalar, typename ResScalar, typename ResPacket, typename DataMapper, int IDX>
|
||||
struct RhsLoopStruct
|
||||
{
|
||||
static constexpr auto PREVIOUS = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_POINTER];
|
||||
static constexpr auto PREVIOUS = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_RHS_POINTER];
|
||||
RhsLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, PREVIOUS> rhsLS;
|
||||
|
||||
EIGEN_STRONG_INLINE void operator()(Index colIdx, const DataMapper& res,
|
||||
|
Loading…
x
Reference in New Issue
Block a user