This commit is contained in:
Everton Constantino 2021-05-14 16:26:33 +00:00
parent ad67705447
commit 5d47f6697d
4 changed files with 108 additions and 73 deletions

View File

@ -16,29 +16,29 @@ namespace internal {
#ifdef __ENABLE_VECTOR_KERNELS__
#define MICRO_12x1x4() \
pLhs = pload<LhsPacket>(lhsPackMap.pCur); \
pLhs2 = pload<LhsPacket>(lhsPackMap.pCur + 4); \
pLhs3 = pload<LhsPacket>(lhsPackMap.pCur + 8); \
pRhs = pload<RhsPacket>(lhsPackMap.pCur);\
pRhs0 = pset1<RhsPacket>(pRhs[0]); \
acc._acc1.packet[0] += pLhs*pRhs0; \
acc._acc2.packet[0] += pLhs2*pRhs0; \
acc._acc3.packet[0] += pLhs3*pRhs0; \
pRhs1 = pset1<RhsPacket>(pRhs[1]); \
acc._acc1.packet[1] += pLhs*pRhs1; \
acc._acc2.packet[1] += pLhs2*pRhs1; \
acc._acc3.packet[1] += pLhs3*pRhs1; \
pRhs2 = pset1<RhsPacket>(pRhs[2]); \
acc._acc1.packet[2] += pLhs*pRhs2; \
acc._acc2.packet[2] += pLhs2*pRhs2; \
acc._acc3.packet[2] += pLhs3*pRhs2; \
pRhs3 = pset1<RhsPacket>(pRhs[3]); \
acc._acc1.packet[3] += pLhs*pRhs3; \
acc._acc2.packet[3] += pLhs2*pRhs3; \
acc._acc3.packet[3] += pLhs3*pRhs3; \
rhsPackMap.advance(4); \
lhsPackMap.advance(12);
#define MICRO_12x1x4(K) \
lhsPackMap.prefetch((3*K + 16)*4); \
rhsPackMap.prefetch((4*K + 16)*1); \
pLhs = pload<LhsPacket>(lhsPackMap.pCur + (0 + 3*K)*4); \
pLhs2 = pload<LhsPacket>(lhsPackMap.pCur + (1 + 3*K)*4); \
pLhs3 = pload<LhsPacket>(lhsPackMap.pCur + (2 + 3*K)*4); \
pRhs = pload<RhsPacket>(lhsPackMap.pCur + (0 + 4*K)*1);\
pRhs0 = pset1<RhsPacket>(pRhs[0]); \
acc._acc1.packet[0] += pLhs*pRhs0; \
acc._acc2.packet[0] += pLhs2*pRhs0; \
acc._acc3.packet[0] += pLhs3*pRhs0; \
pRhs1 = pset1<RhsPacket>(pRhs[1]); \
acc._acc1.packet[1] += pLhs*pRhs1; \
acc._acc2.packet[1] += pLhs2*pRhs1; \
acc._acc3.packet[1] += pLhs3*pRhs1; \
pRhs2 = pset1<RhsPacket>(pRhs[2]); \
acc._acc1.packet[2] += pLhs*pRhs2; \
acc._acc2.packet[2] += pLhs2*pRhs2; \
acc._acc3.packet[2] += pLhs3*pRhs2; \
pRhs3 = pset1<RhsPacket>(pRhs[3]); \
acc._acc1.packet[3] += pLhs*pRhs3; \
acc._acc2.packet[3] += pLhs2*pRhs3; \
acc._acc3.packet[3] += pLhs3*pRhs3;
#define MICRO_8x1x4() \
pLhs = pload<LhsPacket>(lhsPackMap.pCur); \
@ -283,6 +283,21 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 4>
LinearMapper r2 = dest.getLinearMapper(row, col + 2);
LinearMapper r3 = dest.getLinearMapper(row, col + 3);
ResPacket R00 = r0.template loadPacket<ResPacket>(0*PacketSize);
ResPacket R01 = r1.template loadPacket<ResPacket>(0*PacketSize);
ResPacket R02 = r2.template loadPacket<ResPacket>(0*PacketSize);
ResPacket R03 = r3.template loadPacket<ResPacket>(0*PacketSize);
R00 += pAlpha*_acc.packet[0];
R01 += pAlpha*_acc.packet[1];
R02 += pAlpha*_acc.packet[2];
R03 += pAlpha*_acc.packet[3];
r0.storePacket(0*PacketSize, R00);
r1.storePacket(0*PacketSize, R01);
r2.storePacket(0*PacketSize, R02);
r3.storePacket(0*PacketSize, R03);
r0.storePacket(0*PacketSize, r0.template loadPacket<ResPacket>(0*PacketSize) + pAlpha*_acc.packet[0]);
r1.storePacket(0*PacketSize, r1.template loadPacket<ResPacket>(0*PacketSize) + pAlpha*_acc.packet[1]);
r2.storePacket(0*PacketSize, r2.template loadPacket<ResPacket>(0*PacketSize) + pAlpha*_acc.packet[2]);
@ -315,10 +330,11 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 8, 4>
EIGEN_STRONG_INLINE void prefetch(const DataMapper& dest, Index row, Index col)
{
dest.getLinearMapper(row + 0, col + 0).prefetch(0);
dest.getLinearMapper(row + 0, col + 1).prefetch(0);
dest.getLinearMapper(row + 0, col + 2).prefetch(0);
dest.getLinearMapper(row + 0, col + 3).prefetch(0);
constexpr Index offset = 32 / sizeof(ResScalar);
dest.getLinearMapper(row, col + 0).prefetch(offset);
dest.getLinearMapper(row, col + 1).prefetch(offset);
dest.getLinearMapper(row, col + 2).prefetch(offset);
dest.getLinearMapper(row, col + 3).prefetch(offset);
}
template<typename ResPacket_>
@ -345,15 +361,35 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 8, 4>
LinearMapper r2 = dest.getLinearMapper(row, col + 2);
LinearMapper r3 = dest.getLinearMapper(row, col + 3);
r0.storePacket(0*PacketSize, r0.template loadPacket<ResPacket>(0*PacketSize) + pAlpha*_acc1.packet[0]);
r1.storePacket(0*PacketSize, r1.template loadPacket<ResPacket>(0*PacketSize) + pAlpha*_acc1.packet[1]);
r2.storePacket(0*PacketSize, r2.template loadPacket<ResPacket>(0*PacketSize) + pAlpha*_acc1.packet[2]);
r3.storePacket(0*PacketSize, r3.template loadPacket<ResPacket>(0*PacketSize) + pAlpha*_acc1.packet[3]);
ResPacket R00 = r0.template loadPacket<ResPacket>(0*PacketSize);
ResPacket R01 = r1.template loadPacket<ResPacket>(0*PacketSize);
ResPacket R02 = r2.template loadPacket<ResPacket>(0*PacketSize);
ResPacket R03 = r3.template loadPacket<ResPacket>(0*PacketSize);
r0.storePacket(1*PacketSize, r0.template loadPacket<ResPacket>(1*PacketSize) + pAlpha*_acc2.packet[0]);
r1.storePacket(1*PacketSize, r1.template loadPacket<ResPacket>(1*PacketSize) + pAlpha*_acc2.packet[1]);
r2.storePacket(1*PacketSize, r2.template loadPacket<ResPacket>(1*PacketSize) + pAlpha*_acc2.packet[2]);
r3.storePacket(1*PacketSize, r3.template loadPacket<ResPacket>(1*PacketSize) + pAlpha*_acc2.packet[3]);
ResPacket R10 = r0.template loadPacket<ResPacket>(1*PacketSize);
ResPacket R11 = r1.template loadPacket<ResPacket>(1*PacketSize);
ResPacket R12 = r2.template loadPacket<ResPacket>(1*PacketSize);
ResPacket R13 = r3.template loadPacket<ResPacket>(1*PacketSize);
R00 += pAlpha*_acc1.packet[0];
R01 += pAlpha*_acc1.packet[1];
R02 += pAlpha*_acc1.packet[2];
R03 += pAlpha*_acc1.packet[3];
R10 += pAlpha*_acc2.packet[0];
R11 += pAlpha*_acc2.packet[1];
R12 += pAlpha*_acc2.packet[2];
R13 += pAlpha*_acc2.packet[3];
r0.storePacket(0*PacketSize, R00);
r1.storePacket(0*PacketSize, R01);
r2.storePacket(0*PacketSize, R02);
r3.storePacket(0*PacketSize, R03);
r0.storePacket(1*PacketSize, R10);
r1.storePacket(1*PacketSize, R11);
r2.storePacket(1*PacketSize, R12);
r3.storePacket(1*PacketSize, R13);
}
};
@ -519,14 +555,14 @@ struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap,
#if __UNROLL__ == 8
#ifdef __ENABLE_PREFETCH__
prefetch(rhsPackMap.pCur + (48+0));
rhsPackMap.prefetch(48+0);
#endif
MICRO_8x1x4();
MICRO_8x1x4();
MICRO_8x1x4();
MICRO_8x1x4();
#ifdef __ENABLE_PREFETCH__
prefetch(rhsPackMap.pCur + (48+16));
rhsPackMap.prefetch(48+16);
#endif
MICRO_8x1x4();
MICRO_8x1x4();
@ -560,21 +596,23 @@ struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap,
#if __UNROLL__ == 8
#ifdef __ENABLE_PREFETCH__
prefetch(rhsPackMap.pCur);
rhsPackMap.prefetch(0);
#endif
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4(0);
MICRO_12x1x4(1);
MICRO_12x1x4(2);
MICRO_12x1x4(3);
MICRO_12x1x4(4);
MICRO_12x1x4(5);
MICRO_12x1x4(6);
MICRO_12x1x4(7);
lhsPackMap.advance(12*__UNROLL__);
rhsPackMap.advance(4*__UNROLL__);
#else
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4();
MICRO_12x1x4(0);
MICRO_12x1x4(1);
MICRO_12x1x4(2);
MICRO_12x1x4(3);
#endif
asm __volatile__("#END_NEON_MICROKERNEL_12x8x4\n\t");
};
@ -596,7 +634,10 @@ struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap,
LhsPacket pLhs, pLhs2, pLhs3;
RhsPacket pRhs, pRhs0, pRhs1, pRhs2, pRhs3;
MICRO_12x1x4();
MICRO_12x1x4(0);
lhsPackMap.advance(12);
rhsPackMap.advance(4);
asm __volatile__("#END_NEON_MICROKERNEL_12x1x4\n\t");
};

View File

@ -217,7 +217,8 @@ struct PackMap
EIGEN_STRONG_INLINE void resetCur() { pCur = pBase; }
EIGEN_STRONG_INLINE void updateBase() { pBase = pCur; }
EIGEN_STRONG_INLINE void moveTo(Index p1) { pCur = pBase + pmc.getPosition(p1, d2Size); }
EIGEN_STRONG_INLINE void advance(int progress) { pCur += progress; }
EIGEN_STRONG_INLINE void advance(Index progress) { pCur += progress; }
EIGEN_STRONG_INLINE void prefetch(Index amnt) { internal::prefetch(pCur + amnt); }
};
template<int Architecture, int CPU, typename Scalar, typename ResScalar, typename DataMapper, int M, int N>
@ -319,26 +320,23 @@ struct DepthLoopStruct
constexpr auto lhsProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[LHS_SHAPE_IDX][SHAPES_LHS_DIMENSION];
constexpr auto depthProgress = SHAPES<Architecture, CPU, LhsScalar, RhsScalar>[IDX][SHAPES_DEP_DIMENSION];
#ifdef __ENABLE_PREFETCH__
prefetch(lhsPackMap.pCur);
prefetch(rhsPackMap.pCur);
#endif
typedef Accumulator<Architecture, CPU, AccScalar, ResScalar, DataMapper, lhsProgress, rhsProgress> AccumulatorType;
MicroKernel<Architecture, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, AccumulatorType, lhsProgress, depthProgress, rhsProgress> mkt;
AccumulatorType acc;
acc.zero();
#ifdef __ENABLE_PREFETCH__
lhsPackMap.prefetch(0);
acc.prefetch(res, rowIdx, colIdx);
rhsPackMap.prefetch(0);
#endif
for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress)
{
mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc);
}
//acc.scale(alpha, pAlpha);
acc.store(res, rowIdx, colIdx, alpha, pAlpha);
depthLS(rowIdx, colIdx, depthIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap);

View File

@ -51,23 +51,19 @@ int main(int argc, char* argv[])
std::cout << D << std::endl;
#else
if(argc < 3)
if(argc < 5)
{
std::cout << "Wrong number of arguments." << std::endl;
return -1;
}
int sz = std::atoi(argv[1]);
int m = sz, k = sz, n = sz;
int RUNS = std::atoi(argv[2]);
int m = std::atoi(argv[1]), k = std::atoi(argv[2]), n = std::atoi(argv[3]);
int RUNS = std::atoi(argv[4]);
double time = 0;
MatrixXf A = MatrixXf::Random(m,k);
MatrixXf B = MatrixXf::Random(k,n);
for(auto i = 0; i < RUNS; i++)
{
MatrixXf A = MatrixXf::Random(m,k);
MatrixXf B = MatrixXf::Random(k,n);
//set(A,m, k, 1);
//set(B,k, n, 2);
MatrixXf C = MatrixXf::Zero(m, n);
std::clock_t start,end;

12
run.sh
View File

@ -9,13 +9,13 @@ function run() {
for ((i = 0; i < $EXECS; i++)) do
SEL=$(A=$(shuf -i 0-10 -n 1); echo $(($A % 2)))
if [ $SEL -eq 0 ]; then
T_OLD=$(./gto $SIZE $RUNS)
T_NEW=$(./gt $SIZE $RUNS)
T_NEWP=$(./gtp $SIZE $RUNS)
T_OLD=$(./gto $SIZE $SIZE $SIZE $RUNS)
T_NEW=$(./gt $SIZE $SIZE $SIZE $RUNS)
T_NEWP=$(./gtp $SIZE $SIZE $SIZE $RUNS)
else
T_NEW=$(./gt $SIZE $RUNS)
T_NEWP=$(./gtp $SIZE $RUNS)
T_OLD=$(./gto $SIZE $RUNS)
T_NEW=$(./gt $SIZE $SIZE $SIZE $RUNS)
T_NEWP=$(./gtp $SIZE $SIZE $SIZE $RUNS)
T_OLD=$(./gto $SIZE $SIZE $SIZE $RUNS)
fi
NEW=$NEW+$T_NEW
OLD=$OLD+$T_OLD