diff --git a/Eigen/src/Core/arch/NEON/Kernels.h b/Eigen/src/Core/arch/NEON/Kernels.h index 79f19e477..8e24ab659 100644 --- a/Eigen/src/Core/arch/NEON/Kernels.h +++ b/Eigen/src/Core/arch/NEON/Kernels.h @@ -115,6 +115,13 @@ namespace internal { lhsPackMap.advance(4); \ rhsPackMap.advance(1); +#define MICRO_2x1x1() \ + pLhs = pload(lhsPackMap.pCur); \ + pRhs = pset1(*rhsPackMap.pCur); \ + acc._acc += pRhs*pLhs; \ + lhsPackMap.advance(2); \ + rhsPackMap.advance(1); + template struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 12, 1> { @@ -225,6 +232,38 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 1> } }; +template +struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 2, 1> +{ + using LinearMapper = typename DataMapper::LinearMapper; + using AccPacket = typename packet_traits::half; + using ResPacket = typename packet_traits::half; + + AccPacket _acc; + + EIGEN_STRONG_INLINE void zero() + { + _acc = pset1(0); + } + + template + EIGEN_STRONG_INLINE void prefetch(const DataMapper&, Index, Index) {} + + template + EIGEN_STRONG_INLINE void scale(ResScalar alpha, const ResPacket_& pAlpha) + { + _acc *= pAlpha; + } + + template + EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col, ResScalar alpha, const ResPacket_& pAlpha) + { + PacketBlock block; + block.packet[0] = dest.template loadPacket(row, col) + pAlpha*_acc; + dest.template storePacketBlock(row, col, block); + } +}; + template struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 1, 4> { @@ -985,6 +1024,28 @@ struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, }; }; +template +struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, Accumulator, 2, 1, 1> +{ + EIGEN_STRONG_INLINE void operator()(LhsPackMap& lhsPackMap, + RhsPackMap& rhsPackMap, + Index rowIdx, Index colIdx, Index depthIdx, + Accumulator& acc) + { + using LhsPacket = typename packet_traits::half; + using RhsPacket = typename packet_traits::half; + + asm __volatile__("#BEGIN_NEON_MICROKERNEL_4x1x1\n\t"); + + LhsPacket pLhs; + RhsPacket pRhs; + + MICRO_2x1x1(); + + asm __volatile__("#END_NEON_MICROKERNEL_4x1x1\n\t"); + }; +}; + template struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, Accumulator, 1, 1, 4> { diff --git a/Eigen/src/Core/arch/NEON/MatrixProduct.h b/Eigen/src/Core/arch/NEON/MatrixProduct.h index 9b37d7098..9b14537ba 100644 --- a/Eigen/src/Core/arch/NEON/MatrixProduct.h +++ b/Eigen/src/Core/arch/NEON/MatrixProduct.h @@ -292,26 +292,6 @@ struct MicroKernel Index rowIdx, Index colIdx, Index depthIdx, Accumulator& acc) { -#ifdef __DEBUG__ - std::cout << "Kernel " << M << " x " << K << " x " << N << " @ " << rowIdx << ", " << depthIdx << ", " << colIdx << std::endl; - std::cout << "LHS "; - for(auto i = 0; i < M; i++) - { - for(auto j = 0; j < K; j++) - { - std::cout << lhsPackMap.pCur[i*K + j] << " "; - } - } - std::cout << std::endl << "RHS "; - for(auto i = 0; i < K; i++) - { - for(auto j = 0; j < N; j++) - { - std::cout << rhsPackMap.pCur[i*N + j] << " "; - } - } - std::cout << std::endl; -#endif const RhsScalar *pRhs = rhsPackMap.pCur; for(auto i = 0; i < N; i++) { @@ -326,26 +306,26 @@ struct MicroKernel }; }; -template +template struct DepthLoopStruct { static constexpr auto PREVIOUS = SHAPES[IDX][SHAPES_DEP_POINTER]; - DepthLoopStruct depthLS; + DepthLoopStruct depthLS; - EIGEN_STRONG_INLINE void operator()(Index rowIdx, Index colIdx, Index depthIdx, const DataMapper& res, + EIGEN_STRONG_INLINE void operator()(Index rowIdx, Index colIdx, Index depthIdx, const DataMapper& res, AccumulatorType& acc, Index rows, Index depth, Index cols, ResScalar alpha, const ResPacket& pAlpha, LhsPackMap& lhsPackMap, RhsPackMap& rhsPackMap) { - constexpr int rhsProgress = SHAPES[RHS_SHAPE_IDX][SHAPES_RHS_DIMENSION]; - constexpr int lhsProgress = SHAPES[LHS_SHAPE_IDX][SHAPES_LHS_DIMENSION]; - constexpr int depthProgress = SHAPES[IDX][SHAPES_DEP_DIMENSION]; + constexpr auto rhsProgress = SHAPES[RHS_SHAPE_IDX][SHAPES_RHS_DIMENSION]; + constexpr auto lhsProgress = SHAPES[LHS_SHAPE_IDX][SHAPES_LHS_DIMENSION]; + constexpr auto depthProgress = SHAPES[IDX][SHAPES_DEP_DIMENSION]; - typedef Accumulator AccumulatorType; + //typedef Accumulator AccumulatorType; MicroKernel mkt; - AccumulatorType acc; + //AccumulatorType acc; - acc.zero(); + //acc.zero(); acc.template prefetch(res, rowIdx, colIdx); @@ -354,18 +334,41 @@ struct DepthLoopStruct for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress) { +#ifdef __DEBUG__ + auto M = lhsProgress; + auto K = depthProgress; + auto N = rhsProgress; + std::cout << "Kernel " << M << " x " << K << " x " << N << " @ " << rowIdx << ", " << depthIdx << ", " << colIdx << std::endl; + std::cout << "LHS "; + for(auto i = 0; i < M; i++) + { + for(auto j = 0; j < K; j++) + { + std::cout << lhsPackMap.pCur[i*K + j] << " "; + } + } + std::cout << std::endl << "RHS "; + for(auto i = 0; i < K; i++) + { + for(auto j = 0; j < N; j++) + { + std::cout << rhsPackMap.pCur[i*N + j] << " "; + } + } + std::cout << std::endl; +#endif mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc); } - acc.store(res, rowIdx, colIdx, alpha, pAlpha); + //acc.store(res, rowIdx, colIdx, alpha, pAlpha); - depthLS(rowIdx, colIdx, depthIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); + depthLS(rowIdx, colIdx, depthIdx, res, acc, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); } }; -template -struct DepthLoopStruct +template +struct DepthLoopStruct { - EIGEN_STRONG_INLINE void operator()(Index, Index, Index, const DataMapper&, + EIGEN_STRONG_INLINE void operator()(Index, Index, Index, const DataMapper&, AccumulatorType&, Index, Index, Index, ResScalar, const ResPacket&, LhsPackMap&, RhsPackMap&) {} }; @@ -380,15 +383,22 @@ struct LhsLoopStruct { constexpr auto lhsProgress = SHAPES[IDX][SHAPES_LHS_DIMENSION]; constexpr auto rhsProgress = SHAPES[IDX][SHAPES_RHS_DIMENSION]; - DepthLoopStruct depthLS; + + typedef Accumulator AccumulatorType; + + DepthLoopStruct depthLS; + //rhsPackMap.resetCur(); for(;rowIdx + lhsProgress <= rows; rowIdx+=lhsProgress) { rhsPackMap.resetCur(); + AccumulatorType acc; + acc.zero(); //lhsPackMap.moveTo(rowIdx); //rhsPackMap.moveTo(colIdx); - depthLS(rowIdx, colIdx, 0, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); + depthLS(rowIdx, colIdx, 0, res, acc, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); + acc.store(res, rowIdx, colIdx, alpha, pAlpha); } lhsLS(rowIdx, colIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); } diff --git a/new_gemm_test.cpp b/new_gemm_test.cpp index b936924ca..9978d3a38 100644 --- a/new_gemm_test.cpp +++ b/new_gemm_test.cpp @@ -28,9 +28,14 @@ int main(int argc, char* argv[]) for(auto i = 0; i < 2; i++) C = A*B; +#ifdef __DEBUG_SHOW_INPUTS__ std::cout << A << std::endl; std::cout << B << std::endl; +#endif + +#ifdef __DEBUG_SHOW_RESULT__ std::cout << C << std::endl; +#endif std::cout << std::endl; @@ -50,8 +55,9 @@ int main(int argc, char* argv[]) } } } - +#ifdef __DEBUG_SHOW_RESULT__ std::cout << D << std::endl; +#endif #else if(argc < 5) {