From 5d47f6697df33aa12f53dfb606b0ff81494e853f Mon Sep 17 00:00:00 2001
From: Everton Constantino
Date: Fri, 14 May 2021 16:26:33 +0000
Subject: [PATCH] WIP 2

---
Review notes (free-form text in this position is not part of the commit message):
 * MICRO_12x1x4(K): pRhs is now loaded from rhsPackMap.pCur. It was read from
   lhsPackMap.pCur, so the RHS pack buffer was prefetched and advanced but
   never actually read.
 * 12x8x4 micro-kernel: the macro no longer advances the pack maps itself, so
   the branch taken when __UNROLL__ != 8 now advances them explicitly after its
   four steps (12*4 LHS elements, 4*4 RHS elements), mirroring the unroll-8 branch.
 * Accumulator<..., 4, 4>::store: the previous load+add+store statements are
   still present as context right after the new R00..R03 sequence, so the
   accumulator appears to be applied twice (the 8x4 variant removes them).
   Left untouched here because the end of that function is outside the hunk;
   needs a follow-up.
 * NOTE(review): template argument lists (e.g. on pload/pset1/loadPacket and
   the bare "template" lines) are absent in this copy, presumably lost in
   extraction; they are reproduced as found. Indentation and blank context
   lines were reconstructed from the hunk line counts. Confirm against the
   original commit before applying.

 Eigen/src/Core/arch/NEON/Kernels.h       | 145 +++++++++++++++-------
 Eigen/src/Core/arch/NEON/MatrixProduct.h |  12 +-
 new_gemm_test.cpp                        |  14 +--
 run.sh                                   |  12 +-
 4 files changed, 110 insertions(+), 73 deletions(-)

diff --git a/Eigen/src/Core/arch/NEON/Kernels.h b/Eigen/src/Core/arch/NEON/Kernels.h
index b7a673568..b69834070 100644
--- a/Eigen/src/Core/arch/NEON/Kernels.h
+++ b/Eigen/src/Core/arch/NEON/Kernels.h
@@ -16,29 +16,29 @@ namespace internal {
 
 #ifdef __ENABLE_VECTOR_KERNELS__
 
-#define MICRO_12x1x4() \
-  pLhs = pload(lhsPackMap.pCur); \
-  pLhs2 = pload(lhsPackMap.pCur + 4); \
-  pLhs3 = pload(lhsPackMap.pCur + 8); \
-  pRhs = pload(lhsPackMap.pCur);\
-  pRhs0 = pset1(pRhs[0]); \
-  acc._acc1.packet[0] += pLhs*pRhs0; \
-  acc._acc2.packet[0] += pLhs2*pRhs0; \
-  acc._acc3.packet[0] += pLhs3*pRhs0; \
-  pRhs1 = pset1(pRhs[1]); \
-  acc._acc1.packet[1] += pLhs*pRhs1; \
-  acc._acc2.packet[1] += pLhs2*pRhs1; \
-  acc._acc3.packet[1] += pLhs3*pRhs1; \
-  pRhs2 = pset1(pRhs[2]); \
-  acc._acc1.packet[2] += pLhs*pRhs2; \
-  acc._acc2.packet[2] += pLhs2*pRhs2; \
-  acc._acc3.packet[2] += pLhs3*pRhs2; \
-  pRhs3 = pset1(pRhs[3]); \
-  acc._acc1.packet[3] += pLhs*pRhs3; \
-  acc._acc2.packet[3] += pLhs2*pRhs3; \
-  acc._acc3.packet[3] += pLhs3*pRhs3; \
-  rhsPackMap.advance(4); \
-  lhsPackMap.advance(12);
+#define MICRO_12x1x4(K) \
+  lhsPackMap.prefetch((3*K + 16)*4); \
+  rhsPackMap.prefetch((4*K + 16)*1); \
+  pLhs = pload(lhsPackMap.pCur + (0 + 3*K)*4); \
+  pLhs2 = pload(lhsPackMap.pCur + (1 + 3*K)*4); \
+  pLhs3 = pload(lhsPackMap.pCur + (2 + 3*K)*4); \
+  pRhs = pload(rhsPackMap.pCur + (0 + 4*K)*1);\
+  pRhs0 = pset1(pRhs[0]); \
+  acc._acc1.packet[0] += pLhs*pRhs0; \
+  acc._acc2.packet[0] += pLhs2*pRhs0; \
+  acc._acc3.packet[0] += pLhs3*pRhs0; \
+  pRhs1 = pset1(pRhs[1]); \
+  acc._acc1.packet[1] += pLhs*pRhs1; \
+  acc._acc2.packet[1] += pLhs2*pRhs1; \
+  acc._acc3.packet[1] += pLhs3*pRhs1; \
+  pRhs2 = pset1(pRhs[2]); \
+  acc._acc1.packet[2] += pLhs*pRhs2; \
+  acc._acc2.packet[2] += pLhs2*pRhs2; \
+  acc._acc3.packet[2] += pLhs3*pRhs2; \
+  pRhs3 = pset1(pRhs[3]); \
+  acc._acc1.packet[3] += pLhs*pRhs3; \
+  acc._acc2.packet[3] += pLhs2*pRhs3; \
+  acc._acc3.packet[3] += pLhs3*pRhs3;
 
 #define MICRO_8x1x4() \
   pLhs = pload(lhsPackMap.pCur); \
@@ -283,6 +283,21 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 4>
     LinearMapper r2 = dest.getLinearMapper(row, col + 2);
     LinearMapper r3 = dest.getLinearMapper(row, col + 3);
 
+    ResPacket R00 = r0.template loadPacket(0*PacketSize);
+    ResPacket R01 = r1.template loadPacket(0*PacketSize);
+    ResPacket R02 = r2.template loadPacket(0*PacketSize);
+    ResPacket R03 = r3.template loadPacket(0*PacketSize);
+
+    R00 += pAlpha*_acc.packet[0];
+    R01 += pAlpha*_acc.packet[1];
+    R02 += pAlpha*_acc.packet[2];
+    R03 += pAlpha*_acc.packet[3];
+
+    r0.storePacket(0*PacketSize, R00);
+    r1.storePacket(0*PacketSize, R01);
+    r2.storePacket(0*PacketSize, R02);
+    r3.storePacket(0*PacketSize, R03);
+
     r0.storePacket(0*PacketSize, r0.template loadPacket(0*PacketSize) + pAlpha*_acc.packet[0]);
     r1.storePacket(0*PacketSize, r1.template loadPacket(0*PacketSize) + pAlpha*_acc.packet[1]);
     r2.storePacket(0*PacketSize, r2.template loadPacket(0*PacketSize) + pAlpha*_acc.packet[2]);
@@ -315,10 +330,11 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 8, 4>
 
   EIGEN_STRONG_INLINE void prefetch(const DataMapper& dest, Index row, Index col)
   {
-    dest.getLinearMapper(row + 0, col + 0).prefetch(0);
-    dest.getLinearMapper(row + 0, col + 1).prefetch(0);
-    dest.getLinearMapper(row + 0, col + 2).prefetch(0);
-    dest.getLinearMapper(row + 0, col + 3).prefetch(0);
+    constexpr Index offset = 32 / sizeof(ResScalar);
+    dest.getLinearMapper(row, col + 0).prefetch(offset);
+    dest.getLinearMapper(row, col + 1).prefetch(offset);
+    dest.getLinearMapper(row, col + 2).prefetch(offset);
+    dest.getLinearMapper(row, col + 3).prefetch(offset);
   }
 
   template
@@ -345,15 +361,35 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 8, 4>
     LinearMapper r2 = dest.getLinearMapper(row, col + 2);
     LinearMapper r3 = dest.getLinearMapper(row, col + 3);
 
-    r0.storePacket(0*PacketSize, r0.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[0]);
-    r1.storePacket(0*PacketSize, r1.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[1]);
-    r2.storePacket(0*PacketSize, r2.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[2]);
-    r3.storePacket(0*PacketSize, r3.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[3]);
+    ResPacket R00 = r0.template loadPacket(0*PacketSize);
+    ResPacket R01 = r1.template loadPacket(0*PacketSize);
+    ResPacket R02 = r2.template loadPacket(0*PacketSize);
+    ResPacket R03 = r3.template loadPacket(0*PacketSize);
 
-    r0.storePacket(1*PacketSize, r0.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[0]);
-    r1.storePacket(1*PacketSize, r1.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[1]);
-    r2.storePacket(1*PacketSize, r2.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[2]);
-    r3.storePacket(1*PacketSize, r3.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[3]);
+    ResPacket R10 = r0.template loadPacket(1*PacketSize);
+    ResPacket R11 = r1.template loadPacket(1*PacketSize);
+    ResPacket R12 = r2.template loadPacket(1*PacketSize);
+    ResPacket R13 = r3.template loadPacket(1*PacketSize);
+
+    R00 += pAlpha*_acc1.packet[0];
+    R01 += pAlpha*_acc1.packet[1];
+    R02 += pAlpha*_acc1.packet[2];
+    R03 += pAlpha*_acc1.packet[3];
+
+    R10 += pAlpha*_acc2.packet[0];
+    R11 += pAlpha*_acc2.packet[1];
+    R12 += pAlpha*_acc2.packet[2];
+    R13 += pAlpha*_acc2.packet[3];
+
+    r0.storePacket(0*PacketSize, R00);
+    r1.storePacket(0*PacketSize, R01);
+    r2.storePacket(0*PacketSize, R02);
+    r3.storePacket(0*PacketSize, R03);
+
+    r0.storePacket(1*PacketSize, R10);
+    r1.storePacket(1*PacketSize, R11);
+    r2.storePacket(1*PacketSize, R12);
+    r3.storePacket(1*PacketSize, R13);
   }
 };
 
@@ -519,14 +555,14 @@ struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap,
 
 #if __UNROLL__ == 8
 #ifdef __ENABLE_PREFETCH__
-    prefetch(rhsPackMap.pCur + (48+0));
+    rhsPackMap.prefetch(48+0);
 #endif
     MICRO_8x1x4();
     MICRO_8x1x4();
     MICRO_8x1x4();
     MICRO_8x1x4();
 #ifdef __ENABLE_PREFETCH__
-    prefetch(rhsPackMap.pCur + (48+16));
+    rhsPackMap.prefetch(48+16);
 #endif
     MICRO_8x1x4();
     MICRO_8x1x4();
@@ -560,21 +596,25 @@ struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap,
 
 #if __UNROLL__ == 8
 #ifdef __ENABLE_PREFETCH__
-    prefetch(rhsPackMap.pCur);
+    rhsPackMap.prefetch(0);
 #endif
-    MICRO_12x1x4();
-    MICRO_12x1x4();
-    MICRO_12x1x4();
-    MICRO_12x1x4();
-    MICRO_12x1x4();
-    MICRO_12x1x4();
-    MICRO_12x1x4();
-    MICRO_12x1x4();
+    MICRO_12x1x4(0);
+    MICRO_12x1x4(1);
+    MICRO_12x1x4(2);
+    MICRO_12x1x4(3);
+    MICRO_12x1x4(4);
+    MICRO_12x1x4(5);
+    MICRO_12x1x4(6);
+    MICRO_12x1x4(7);
+    lhsPackMap.advance(12*__UNROLL__);
+    rhsPackMap.advance(4*__UNROLL__);
 #else
-    MICRO_12x1x4();
-    MICRO_12x1x4();
-    MICRO_12x1x4();
-    MICRO_12x1x4();
+    MICRO_12x1x4(0);
+    MICRO_12x1x4(1);
+    MICRO_12x1x4(2);
+    MICRO_12x1x4(3);
+    lhsPackMap.advance(12*4);
+    rhsPackMap.advance(4*4);
 #endif
     asm __volatile__("#END_NEON_MICROKERNEL_12x8x4\n\t");
   };
@@ -596,7 +636,10 @@ struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap,
     LhsPacket pLhs, pLhs2, pLhs3;
     RhsPacket pRhs, pRhs0, pRhs1, pRhs2, pRhs3;
 
-    MICRO_12x1x4();
+    MICRO_12x1x4(0);
+
+    lhsPackMap.advance(12);
+    rhsPackMap.advance(4);
     asm __volatile__("#END_NEON_MICROKERNEL_12x1x4\n\t");
   };
 
diff --git a/Eigen/src/Core/arch/NEON/MatrixProduct.h b/Eigen/src/Core/arch/NEON/MatrixProduct.h
index f68699120..16e9f0114 100644
--- a/Eigen/src/Core/arch/NEON/MatrixProduct.h
+++ b/Eigen/src/Core/arch/NEON/MatrixProduct.h
@@ -217,7 +217,8 @@ struct PackMap
   EIGEN_STRONG_INLINE void resetCur() { pCur = pBase; }
   EIGEN_STRONG_INLINE void updateBase() { pBase = pCur; }
   EIGEN_STRONG_INLINE void moveTo(Index p1) { pCur = pBase + pmc.getPosition(p1, d2Size); }
-  EIGEN_STRONG_INLINE void advance(int progress) { pCur += progress; }
+  EIGEN_STRONG_INLINE void advance(Index progress) { pCur += progress; }
+  EIGEN_STRONG_INLINE void prefetch(Index amnt) { internal::prefetch(pCur + amnt); }
 };
 
 template
@@ -319,26 +320,23 @@ struct DepthLoopStruct
     constexpr auto lhsProgress = SHAPES[LHS_SHAPE_IDX][SHAPES_LHS_DIMENSION];
     constexpr auto depthProgress = SHAPES[IDX][SHAPES_DEP_DIMENSION];
 
-#ifdef __ENABLE_PREFETCH__
-    prefetch(lhsPackMap.pCur);
-    prefetch(rhsPackMap.pCur);
-#endif
-
     typedef Accumulator AccumulatorType;
     MicroKernel mkt;
     AccumulatorType acc;
+
     acc.zero();
 
 #ifdef __ENABLE_PREFETCH__
+    lhsPackMap.prefetch(0);
     acc.prefetch(res, rowIdx, colIdx);
+    rhsPackMap.prefetch(0);
 #endif
 
     for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress)
     {
       mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc);
     }
 
-    //acc.scale(alpha, pAlpha);
     acc.store(res, rowIdx, colIdx, alpha, pAlpha);
 
     depthLS(rowIdx, colIdx, depthIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap);
diff --git a/new_gemm_test.cpp b/new_gemm_test.cpp
index 1524b36d9..42d60d8cf 100644
--- a/new_gemm_test.cpp
+++ b/new_gemm_test.cpp
@@ -51,23 +51,19 @@ int main(int argc, char* argv[])
   std::cout << D << std::endl;
 
 #else
-  if(argc < 3)
+  if(argc < 5)
   {
     std::cout << "Wrong number of arguments." << std::endl;
     return -1;
   }
-
-  int sz = std::atoi(argv[1]);
-  int m = sz, k = sz, n = sz;
-  int RUNS = std::atoi(argv[2]);
+  int m = std::atoi(argv[1]), k = std::atoi(argv[2]), n = std::atoi(argv[3]);
+  int RUNS = std::atoi(argv[4]);
 
   double time = 0;
+  MatrixXf A = MatrixXf::Random(m,k);
+  MatrixXf B = MatrixXf::Random(k,n);
   for(auto i = 0; i < RUNS; i++)
   {
-    MatrixXf A = MatrixXf::Random(m,k);
-    MatrixXf B = MatrixXf::Random(k,n);
-    //set(A,m, k, 1);
-    //set(B,k, n, 2);
     MatrixXf C = MatrixXf::Zero(m, n);
 
     std::clock_t start,end;
diff --git a/run.sh b/run.sh
index 063aee907..80633c689 100755
--- a/run.sh
+++ b/run.sh
@@ -9,13 +9,13 @@ function run() {
     for ((i = 0; i < $EXECS; i++)) do
         SEL=$(A=$(shuf -i 0-10 -n 1); echo $(($A % 2)))
         if [ $SEL -eq 0 ]; then
-            T_OLD=$(./gto $SIZE $RUNS)
-            T_NEW=$(./gt $SIZE $RUNS)
-            T_NEWP=$(./gtp $SIZE $RUNS)
+            T_OLD=$(./gto $SIZE $SIZE $SIZE $RUNS)
+            T_NEW=$(./gt $SIZE $SIZE $SIZE $RUNS)
+            T_NEWP=$(./gtp $SIZE $SIZE $SIZE $RUNS)
         else
-            T_NEW=$(./gt $SIZE $RUNS)
-            T_NEWP=$(./gtp $SIZE $RUNS)
-            T_OLD=$(./gto $SIZE $RUNS)
+            T_NEW=$(./gt $SIZE $SIZE $SIZE $RUNS)
+            T_NEWP=$(./gtp $SIZE $SIZE $SIZE $RUNS)
+            T_OLD=$(./gto $SIZE $SIZE $SIZE $RUNS)
         fi
         NEW=$NEW+$T_NEW
         OLD=$OLD+$T_OLD