This commit is contained in:
Everton Constantino 2021-05-13 18:12:52 +00:00
parent 58db05afbc
commit 3999ab2dc7
3 changed files with 26 additions and 22 deletions

View File

@ -17,30 +17,28 @@ namespace internal {
#ifdef __ENABLE_VECTOR_KERNELS__
#define MICRO_12x1x4() \
pRhs = pload<RhsPacket>(rhsPackMap.pCur); \
rhsPackMap.advance(1*4); \
pRhs0 = pset1<RhsPacket>(pRhs[0]); \
pRhs1 = pset1<RhsPacket>(pRhs[1]); \
pRhs2 = pset1<RhsPacket>(pRhs[2]); \
pRhs3 = pset1<RhsPacket>(pRhs[3]); \
pLhs = pload<LhsPacket>(lhsPackMap.pCur); \
lhsPackMap.advance(4*1); \
pLhs2 = pload<LhsPacket>(lhsPackMap.pCur + 4); \
pLhs3 = pload<LhsPacket>(lhsPackMap.pCur + 8); \
pRhs = pload<RhsPacket>(lhsPackMap.pCur);\
pRhs0 = pset1<RhsPacket>(pRhs[0]); \
acc._acc1.packet[0] += pLhs*pRhs0; \
acc._acc1.packet[1] += pLhs*pRhs1; \
acc._acc1.packet[2] += pLhs*pRhs2; \
acc._acc1.packet[3] += pLhs*pRhs3; \
pLhs2 = pload<LhsPacket>(lhsPackMap.pCur); \
lhsPackMap.advance(4*1); \
acc._acc2.packet[0] += pLhs2*pRhs0; \
acc._acc2.packet[1] += pLhs2*pRhs1; \
acc._acc2.packet[2] += pLhs2*pRhs2; \
acc._acc2.packet[3] += pLhs2*pRhs3; \
pLhs3 = pload<LhsPacket>(lhsPackMap.pCur); \
acc._acc3.packet[0] += pLhs3*pRhs0; \
pRhs1 = pset1<RhsPacket>(pRhs[1]); \
acc._acc1.packet[1] += pLhs*pRhs1; \
acc._acc2.packet[1] += pLhs2*pRhs1; \
acc._acc3.packet[1] += pLhs3*pRhs1; \
pRhs2 = pset1<RhsPacket>(pRhs[2]); \
acc._acc1.packet[2] += pLhs*pRhs2; \
acc._acc2.packet[2] += pLhs2*pRhs2; \
acc._acc3.packet[2] += pLhs3*pRhs2; \
pRhs3 = pset1<RhsPacket>(pRhs[3]); \
acc._acc1.packet[3] += pLhs*pRhs3; \
acc._acc2.packet[3] += pLhs2*pRhs3; \
acc._acc3.packet[3] += pLhs3*pRhs3; \
lhsPackMap.advance(4*1);
rhsPackMap.advance(4); \
lhsPackMap.advance(12);
#define MICRO_8x1x4() \
pLhs = pload<LhsPacket>(lhsPackMap.pCur); \

View File

@ -215,6 +215,7 @@ struct PackMap
PackMap(const Scalar *base, Index d2Size, Index stride, Index offset) : pBase(base), pCur(base), d2Size(d2Size), stride(stride), offset(offset) {}
EIGEN_STRONG_INLINE void resetCur() { pCur = pBase; }
EIGEN_STRONG_INLINE void updateBase() { pBase = pCur; }
EIGEN_STRONG_INLINE void moveTo(Index p1) { pCur = pBase + pmc.getPosition(p1, d2Size); }
EIGEN_STRONG_INLINE void advance(int progress) { pCur += progress; }
};
@ -362,12 +363,12 @@ struct LhsLoopStruct
constexpr auto lhsProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_LHS_DIMENSION];
constexpr auto rhsProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_RHS_DIMENSION];
DepthLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, RHS_SHAPE_IDX, IDX, IDX> depthLS;
rhsPackMap.resetCur();
for(;rowIdx + lhsProgress <= rows; rowIdx+=lhsProgress)
{
lhsPackMap.moveTo(rowIdx);
rhsPackMap.moveTo(colIdx);
//prefetch(lhsPackMap.pCur + 2*lhsProgress);
//prefetch(rhsPackMap.pCur + 2*rhsProgress);
//lhsPackMap.moveTo(rowIdx);
//rhsPackMap.moveTo(colIdx);
depthLS(rowIdx, colIdx, 0, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap);
}
lhsLS(rowIdx, colIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap);
@ -395,7 +396,9 @@ struct RhsLoopStruct
for(;colIdx + rhsProgress <= cols; colIdx+=rhsProgress)
{
LhsLoopStruct<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, ResPacket, DataMapper, IDX, IDX> lhsLS;
lhsPackMap.resetCur();
lhsLS(0, colIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap);
rhsPackMap.updateBase();
}
rhsLS(colIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap);
}

View File

@ -15,7 +15,7 @@ void set(MatrixXf& A, int m, int n, int id, int digits)
int main(int argc, char* argv[])
{
#ifdef __DEBUG__
int m = 32, k = 32, n = 32, max = std::max(std::max(m,k),n);
int m = 9, k = 9, n = 9, max = std::max(std::max(m,k),n);
MatrixXf A = MatrixXf::Zero(m, k);
MatrixXf B = MatrixXf::Zero(k, n);
MatrixXf C = MatrixXf::Zero(m, n);
@ -28,6 +28,7 @@ int main(int argc, char* argv[])
std::cout << A << std::endl;
std::cout << B << std::endl;
std::cout << C << std::endl;
std::cout << std::endl;
@ -47,6 +48,8 @@ int main(int argc, char* argv[])
}
}
}
std::cout << D << std::endl;
#else
if(argc < 3)
{