NEON GEBP: add 4x1 / 1x4 accumulators and 4x1x1 / 1x1x4 micro-kernels

Eigen/src/Core/arch/NEON/Kernels.h
  * Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 1>
      Holds one packet `_acc`. zero() sets it with pset1(0), scale()
      multiplies it by pAlpha (the scalar `alpha` argument is not used),
      and store() adds it to the packet loaded at offset 0 of the linear
      mapper for (row, col) and writes the sum back.
  * Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 1, 4>
      Same zero()/scale(). store() adds lanes 0..3 of `_acc` to
      dest(row, col + 0) .. dest(row, col + 3).
  * MicroKernel<..., 4, 1, 1>
      Loads one packet from lhsPackMap.pCur and one scalar from
      rhsPackMap.pCur, adds their product to acc._acc, then advances the
      lhs map by 4 and the rhs map by 1. rowIdx/colIdx/depthIdx are unused.
  * MicroKernel<..., 1, 1, 4>
      Mirror image: one scalar from lhsPackMap.pCur, one packet from
      rhsPackMap.pCur; advances lhs by 1 and rhs by 4.
  * The assembly-comment markers of the existing kernel are renamed from
    3x1 to 4x1x4.

Eigen/src/Core/arch/NEON/MatrixProduct.h
  * gemm() emits #BEGIN_GEBP / #END_GEBP assembly-comment markers around
    the rhsLS(...) call.

Changes made during review
  1. Accumulator<..., 1, 4>::store() used plain assignment, overwriting
     the destination. It now uses `+=`, matching the 4x1 specialization in
     this same patch, which adds to the loaded destination packet.
  2. "#BEGING_GEBP" corrected to "#BEGIN_GEBP" so it pairs with
     "#END_GEBP" and follows the spelling of the other BEGIN markers.

NOTE(review): this patch was recovered from a copy in which all line breaks
had been collapsed. Line breaks, indentation and blank context lines were
restored by convention; hunk line counts are unchanged. The template
parameter/argument lists also appear to have been stripped (bare `template`,
`packet_traits::type`, `pset1(0)`, `pload(...)`, `RhsLoopStruct-1>`); they
are reproduced as found and must be restored from the upstream sources
before the patch can apply or compile. The `index` blob ids predate the
review changes above.

diff --git a/Eigen/src/Core/arch/NEON/Kernels.h b/Eigen/src/Core/arch/NEON/Kernels.h
index 0eb78f1b1..65f23a9b6 100644
--- a/Eigen/src/Core/arch/NEON/Kernels.h
+++ b/Eigen/src/Core/arch/NEON/Kernels.h
@@ -14,6 +14,63 @@ namespace Eigen {
 
 namespace internal {
 
+template
+struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 1>
+{
+  using LinearMapper = typename DataMapper::LinearMapper;
+  using AccPacket = typename packet_traits::type;
+  using ResPacket = typename packet_traits::type;
+
+  AccPacket _acc;
+
+  EIGEN_STRONG_INLINE void zero()
+  {
+    _acc = pset1(0);
+  }
+
+  template
+  EIGEN_STRONG_INLINE void scale(ResScalar alpha, const ResPacket_& pAlpha)
+  {
+    _acc *= pAlpha;
+  }
+
+  EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col)
+  {
+    LinearMapper r0 = dest.getLinearMapper(row, col + 0);
+
+    r0.storePacket(0, r0.template loadPacket(0) + _acc);
+  }
+};
+
+template
+struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 1, 4>
+{
+  using LinearMapper = typename DataMapper::LinearMapper;
+  using AccPacket = typename packet_traits::type;
+  using ResPacket = typename packet_traits::type;
+
+  AccPacket _acc;
+
+  EIGEN_STRONG_INLINE void zero()
+  {
+    _acc = pset1(0);
+  }
+
+  template
+  EIGEN_STRONG_INLINE void scale(ResScalar alpha, const ResPacket_& pAlpha)
+  {
+    _acc *= pAlpha;
+  }
+
+  EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col)
+  {
+    dest(row, col + 0) += _acc[0];  // add to dest, as the 4x1 store() does
+    dest(row, col + 1) += _acc[1];
+    dest(row, col + 2) += _acc[2];
+    dest(row, col + 3) += _acc[3];
+  }
+};
+
 template
 struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 4>
 {
@@ -65,7 +122,7 @@ struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap,
     using LhsPacket = typename packet_traits::type;
     using RhsPacket = typename packet_traits::type;
 
-    asm __volatile__("#BEGIN_NEON_MICROKERNEL_3x1\n\t");
+    asm __volatile__("#BEGIN_NEON_MICROKERNEL_4x1x4\n\t");
     LhsPacket pLhs = pload(lhsPackMap.pCur);
     RhsPacket pRhs = pload(rhsPackMap.pCur);
     RhsPacket pRhs0 = pset1(pRhs[0]);
@@ -80,7 +137,53 @@ struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap,
 
     lhsPackMap.advance(4*1);
     rhsPackMap.advance(1*4);
-    asm __volatile__("#END_NEON_MICROKERNEL_3x1\n\t");
+    asm __volatile__("#END_NEON_MICROKERNEL_4x1x4\n\t");
+  };
+};
+
+template
+struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, Accumulator, 4, 1, 1>
+{
+  EIGEN_STRONG_INLINE void operator()(LhsPackMap& lhsPackMap,
+                                      RhsPackMap& rhsPackMap,
+                                      Index rowIdx, Index colIdx, Index depthIdx,
+                                      Accumulator& acc)
+  {
+    using LhsPacket = typename packet_traits::type;
+
+    asm __volatile__("#BEGIN_NEON_MICROKERNEL_4x1x1\n\t");
+
+    LhsPacket pLhs = pload(lhsPackMap.pCur);
+    RhsScalar rhs = *rhsPackMap.pCur;
+
+    acc._acc += pLhs*rhs;
+
+    lhsPackMap.advance(4*1);
+    rhsPackMap.advance(1);
+    asm __volatile__("#END_NEON_MICROKERNEL_4x1x1\n\t");
+  };
+};
+
+template
+struct MicroKernel<0, CPU, Index, LhsScalar, LhsPackMap, RhsScalar, RhsPackMap, AccScalar, ResScalar, Accumulator, 1, 1, 4>
+{
+  EIGEN_STRONG_INLINE void operator()(LhsPackMap& lhsPackMap,
+                                      RhsPackMap& rhsPackMap,
+                                      Index rowIdx, Index colIdx, Index depthIdx,
+                                      Accumulator& acc)
+  {
+    using RhsPacket = typename packet_traits::type;
+
+    asm __volatile__("#BEGIN_NEON_MICROKERNEL_1x1x4\n\t");
+
+    RhsPacket pRhs = pload(rhsPackMap.pCur);
+    LhsScalar lhs = *lhsPackMap.pCur;
+
+    acc._acc += pRhs*lhs;
+
+    lhsPackMap.advance(1);
+    rhsPackMap.advance(4*1);
+    asm __volatile__("#END_NEON_MICROKERNEL_1x1x4\n\t");
   };
 };
 
diff --git a/Eigen/src/Core/arch/NEON/MatrixProduct.h b/Eigen/src/Core/arch/NEON/MatrixProduct.h
index 85c2d05f3..1ff0a17bd 100644
--- a/Eigen/src/Core/arch/NEON/MatrixProduct.h
+++ b/Eigen/src/Core/arch/NEON/MatrixProduct.h
@@ -928,6 +928,7 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, co
   }
   std::cout << std::endl;
 #endif
+  asm __volatile__("#BEGIN_GEBP\n\t");
   RhsLoopStruct-1> rhsLS;
   LhsPackMap lhsPackMap(blockA, depth, strideA, offsetA);
   RhsPackMap rhsPackMap(blockB, depth, strideB, offsetB);
@@ -935,6 +936,7 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, co
   ResPacket pAlpha = pset1(alpha);
 
   rhsLS(0, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap);
+  asm __volatile__("#END_GEBP\n\t");
 }
 
 template