From 421891e1db97bd40d961f88536f43012906cedbf Mon Sep 17 00:00:00 2001 From: Everton Constantino Date: Wed, 21 Apr 2021 17:58:55 +0000 Subject: [PATCH] WIP --- Eigen/src/Core/arch/NEON/MatrixProduct.h | 196 +++++++++++++++++------ 1 file changed, 148 insertions(+), 48 deletions(-) diff --git a/Eigen/src/Core/arch/NEON/MatrixProduct.h b/Eigen/src/Core/arch/NEON/MatrixProduct.h index 14370ede3..3e8d7aa41 100644 --- a/Eigen/src/Core/arch/NEON/MatrixProduct.h +++ b/Eigen/src/Core/arch/NEON/MatrixProduct.h @@ -560,7 +560,7 @@ EIGEN_STRONG_INLINE void gemm_old(const DataMapper& res, const LhsScalar* blockA #endif template -constexpr int SHAPES_COUNT = 3; +constexpr int SHAPES_COUNT = 2; constexpr int SHAPES_DIMENSION = 4; constexpr int SHAPES_LHS_DIMENSION = 0; @@ -570,7 +570,7 @@ constexpr int SHAPES_POINTER = 3; constexpr int SHAPES_POINTER_END = -1; template -constexpr int PACK_SHAPES_COUNT = 3; +constexpr int PACK_SHAPES_COUNT = 2; constexpr int PACK_SHAPES_DIMENSION = 3; constexpr int PACK_SHAPES_POINTER = 2; constexpr int PACK_SHAPES_END = -1; @@ -578,66 +578,71 @@ constexpr int PACK_SHAPES_END = -1; // lhs_progress x depth_progress x rhs_progress (depth_progress > 1 matrix ops) x pointer to next rhs_progress on the shapes map template -constexpr int SHAPES[SHAPES_COUNT][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{4,1,4,0},{8,1,8,1}}; -//constexpr int SHAPES[SHAPES_COUNT][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{2,1,1,0},{1,1,2,1},{2,1,2,1},{2,2,2,1}}; +constexpr int SHAPES[SHAPES_COUNT][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{4,1,4,0}}; // d1progress x d2progress template -constexpr int PACK_SHAPES[PACK_SHAPES_COUNT][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0},{8,1,1}}; +constexpr int PACK_SHAPES[PACK_SHAPES_COUNT][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0}}; template -constexpr int PACK_SHAPES[PACK_SHAPES_COUNT][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0},{8,1,1}}; +constexpr int PACK_SHAPES[PACK_SHAPES_COUNT][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0}}; template struct PackingOperator { - EIGEN_STRONG_INLINE void operator()(Index d1Idx, Index d2Idx, Scalar **block, const DataMapper& data) + EIGEN_STRONG_INLINE Scalar* operator()(Index d1Idx, Index d2Idx, Scalar *block, const DataMapper& data) { - std::cout << M << "x" << N << " ( " << d1Idx << ", " << d2Idx <<") -> ( " << d1Idx + M << ", " << d2Idx + N << ")" << std::endl; - Scalar *c = *block; + std::cout << M << "x" << N << " ( " << d1Idx << ", " << d2Idx <<") -> ( " << d1Idx + M << ", " << d2Idx + N << ") "; + Scalar *c = block; for(auto i = 0; i < M; i++) for(auto j = 0; j < N; j++) { - *c = data(d1Idx + i, d2Idx + j); + if(isLhs) + *c = data(d1Idx + i, d2Idx + j); + else + *c = data(d2Idx + j, d1Idx + i); + std::cout << *c << " "; c++; } - - *block = c; + std::cout << std::endl; + return c; } }; template struct PackingInnerStruct { - EIGEN_STRONG_INLINE void operator()(Index d1Idx, Index d2Idx, Scalar *block, const DataMapper& data, Index d1Size, Index d2Size, Index stride, Index offset) + EIGEN_STRONG_INLINE Scalar* operator()(Index d1Idx, Index d2Idx, Scalar *block, const DataMapper& data, Index d1Size, Index d2Size, Index stride, Index offset) { constexpr auto d2Progress = PACK_SHAPES[IDX][1]; PackingOperator po; for(;d2Idx + d2Progress <= d2Size; d2Idx+=d2Progress) { - po(d1Idx, d2Idx, &block, data); + block = po(d1Idx, d2Idx, block, data); } if(PACK_SHAPES[IDX-1][0] == D1PROGRESS) { PackingInnerStruct pis; - pis(d1Idx, d2Idx, block, data, d1Size, d2Size, stride, offset); + block = pis(d1Idx, d2Idx, block, data, d1Size, d2Size, stride, offset); } + return block; } }; template struct PackingInnerStruct { - EIGEN_STRONG_INLINE void operator()(Index d1Idx, Index d2Idx, Scalar *block, const DataMapper& data, Index d1Size, Index d2Size, Index stride, Index offset) + EIGEN_STRONG_INLINE Scalar* operator()(Index d1Idx, Index d2Idx, Scalar *block, const DataMapper& data, Index d1Size, Index d2Size, Index stride, Index offset) { constexpr auto d2Progress = PACK_SHAPES[0][1]; - for(;d2Idx < d2Size; d2Idx++) + for(;d2Idx + d2Progress <= d2Size; d2Idx+=d2Progress) { PackingOperator po; - po(d1Idx, d2Idx, &block, data); + block = po(d1Idx, d2Idx, block, data); } + return block; } }; @@ -646,23 +651,23 @@ struct PackingStruct { PackingStruct[PACK_SHAPE_IDX][PACK_SHAPES_POINTER]> ps; - EIGEN_STRONG_INLINE void operator()(Index d1Idx, Scalar *block, const DataMapper& data, Index d1Size, Index d2Size, Index stride, Index offset) + EIGEN_STRONG_INLINE Scalar* operator()(Index d1Idx, Scalar *block, const DataMapper& data, Index d1Size, Index d2Size, Index stride, Index offset) { constexpr auto d1Progress = PACK_SHAPES[PACK_SHAPE_IDX][0]; for(; d1Idx + d1Progress <= d1Size; d1Idx += d1Progress) { PackingInnerStruct pis; - pis(d1Idx, 0, block, data, d1Size, d2Size, stride, offset); + block = pis(d1Idx, 0, block, data, d1Size, d2Size, stride, offset); } - ps(d1Idx, block, data, d1Size, d2Size, stride, offset); + return ps(d1Idx, block, data, d1Size, d2Size, stride, offset); } }; template struct PackingStruct { - EIGEN_STRONG_INLINE void operator()(Index, Scalar *, const DataMapper&, Index, Index, Index, Index) {} + EIGEN_STRONG_INLINE Scalar* operator()(Index, Scalar *block, const DataMapper&, Index, Index, Index, Index) { return block; } }; template @@ -675,22 +680,68 @@ struct lhs_pack } }; -template +template +struct rhs_pack +{ + EIGEN_STRONG_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) + { + PackingStruct-1> ps; + ps(0, blockB, rhs, cols, depth, stride, offset); + } +}; + +template +struct PackMapCalculator +{ + PackMapCalculator[IDX][PACK_SHAPES_POINTER]> pmc; + + inline Index getPosition(Index pos, Index d2Size) + { + constexpr auto d1Progress = PACK_SHAPES[IDX][0]; + Index v = (pos / d1Progress) * d1Progress; + return v*d2Size + pmc.getPosition(pos - v, d2Size); + } +}; + +template +struct PackMapCalculator +{ + inline Index getPosition(Index, Index) { return Index(0); } +}; + +template +struct PackMap +{ + const Scalar *pBase; + const Scalar *pCur; + Index d2Size; + PackMapCalculator-1> pmc; + + PackMap(const Scalar *base, Index d2Size) : pBase(base), pCur(base), d2Size(d2Size) {} + + inline void resetCur() { pCur = pBase; } + inline void moveTo(Index pos) + { + Index inc = pmc.getPosition(pos, d2Size); + std::cout << isLhs << " MOVE_TO " << pos << " " << inc << std::endl; + pCur = pBase + inc; + } + inline void advance(int progress) { pCur += progress; } +}; + +template struct MicroKernel { - EIGEN_STRONG_INLINE void operator()(const LhsScalar** ppLhs,const RhsScalar** ppRhs, Index rowIdx, Index colIdx, Index depthIdx) + EIGEN_STRONG_INLINE void operator()(PackMap& lhsPackMap, PackMap& rhsPackMap, Index rowIdx, Index colIdx, Index depthIdx) { - const LhsScalar *pLhs = *ppLhs; - const RhsScalar *pRhs = *ppRhs; - std::cout << "Kernel " << M << " x " << K << " x " << N << " @ " << rowIdx << ", " << depthIdx << ", " << colIdx << std::endl; std::cout << "LHS "; for(auto i = rowIdx; i < M + rowIdx; i++) { for(auto j = depthIdx; j < K + depthIdx; j++) { - std::cout << *pLhs << " "; - pLhs++; + std::cout << *lhsPackMap.pCur << " "; + lhsPackMap.advance(1); } } std::cout << std::endl << "RHS "; @@ -698,13 +749,11 @@ struct MicroKernel { for(auto j = colIdx; j < N + colIdx; j++) { - std::cout << *pRhs << " "; - pRhs++; + std::cout << *rhsPackMap.pCur << " "; + rhsPackMap.advance(1); } } std::cout << std::endl; - *ppLhs += M*K; - *ppRhs += N*K; }; }; @@ -713,7 +762,7 @@ struct DepthLoopStruct { DepthLoopStruct depthLS; EIGEN_STRONG_INLINE void operator()(Index rowIdx, Index colIdx, Index depthIdx, const DataMapper& res, const LhsScalar* blockA, const RhsScalar*blockB, - Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, PackMap& lhsPackMap, PackMap& rhsPackMap) { constexpr auto rhsProgress = SHAPES[RHS_SHAPE_IDX][SHAPES_RHS_DIMENSION]; constexpr auto lhsProgress = SHAPES[LHS_SHAPE_IDX][SHAPES_LHS_DIMENSION]; @@ -721,13 +770,13 @@ struct DepthLoopStruct if(rhsProgress == SHAPES[IDX][SHAPES_RHS_DIMENSION] && lhsProgress == SHAPES[IDX][SHAPES_LHS_DIMENSION]) { - MicroKernel mkt; + MicroKernel mkt; for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress) { - mkt(&blockA, &blockB, rowIdx, colIdx, depthIdx); + mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx); } } - depthLS(rowIdx, colIdx, depthIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + depthLS(rowIdx, colIdx, depthIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, lhsPackMap, rhsPackMap); } }; @@ -735,7 +784,7 @@ template { EIGEN_STRONG_INLINE void operator()(Index, Index, Index, const DataMapper&, const LhsScalar*, const RhsScalar*, - Index, Index, Index, ResScalar, Index, Index, Index, Index) {} + Index, Index, Index, ResScalar, Index, Index, Index, Index, PackMap&, PackMap&) {} }; template @@ -743,16 +792,18 @@ struct LhsLoopStruct { LhsLoopStruct lhsLS; EIGEN_STRONG_INLINE void operator()(Index rowIdx, int colIdx, const DataMapper& res, const LhsScalar* blockA, const RhsScalar*blockB, - Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, PackMap& lhsPackMap, PackMap& rhsPackMap) { constexpr auto lhsProgress = SHAPES[IDX][SHAPES_LHS_DIMENSION]; DepthLoopStruct depthLS; for(;rowIdx + lhsProgress <= rows; rowIdx+=lhsProgress) { - depthLS(rowIdx, colIdx, 0, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + lhsPackMap.moveTo(rowIdx); + rhsPackMap.moveTo(colIdx); + depthLS(rowIdx, colIdx, 0, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, lhsPackMap, rhsPackMap); } - lhsLS(rowIdx, colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + lhsLS(rowIdx, colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, lhsPackMap, rhsPackMap); } }; @@ -760,7 +811,7 @@ template { EIGEN_STRONG_INLINE void operator()(Index, Index, const DataMapper&, const LhsScalar*, const RhsScalar*, - Index, Index, Index, ResScalar, Index, Index, Index, Index) {} + Index, Index, Index, ResScalar, Index, Index, Index, Index, PackMap&, PackMap&) {} }; template @@ -770,17 +821,18 @@ struct RhsLoopStruct RhsLoopStruct rhsLS; EIGEN_STRONG_INLINE void operator()(Index colIdx, const DataMapper& res, const LhsScalar* blockA, const RhsScalar*blockB, - Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, PackMap& lhsPackMap, PackMap& rhsPackMap) { constexpr auto rhsProgress = SHAPES[IDX][SHAPES_RHS_DIMENSION]; std::cout << __PRETTY_FUNCTION__ << std::endl; for(;colIdx + rhsProgress <= cols; colIdx+=rhsProgress) { + //rhsPackMap.moveTo(colIdx); LhsLoopStruct lhsLS; - lhsLS(0, colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + lhsLS(0, colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, lhsPackMap, rhsPackMap); } - rhsLS(colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + rhsLS(colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, lhsPackMap, rhsPackMap); } }; @@ -788,15 +840,62 @@ template { EIGEN_STRONG_INLINE void operator()(Index colIdx, const DataMapper&, const LhsScalar*, const RhsScalar*, - Index, Index, Index, ResScalar, Index, Index, Index, Index) {} + Index, Index, Index, ResScalar, Index, Index, Index, Index, PackMap&, PackMap&) {} }; template EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { + std::cout << "blockA" << std::endl; + for(auto i = 0; i < rows*depth; i++) + { + if(i % 4 == 0 && i > 0) + std::cout << std::endl; + std::cout << blockA[i] << " "; + } + std::cout << std::endl; + std::cout << "blockB" << std::endl; + for(auto i = 0; i < depth*cols; i++) + { + if(i % 4 == 0 && i > 0) + std::cout << std::endl; + std::cout << blockB[i] << " "; + } + std::cout << std::endl; + RhsLoopStruct<0, 0, Index, LhsScalar, RhsScalar, AccScalar, ResScalar, DataMapper, SHAPES_COUNT<0, 0, LhsScalar, RhsScalar>-1> rhsLS; - rhsLS(0, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + PackMap<0, 0, Index, LhsScalar, DataMapper, true> lhsPackMap(blockA, depth); + PackMap<0, 0, Index, RhsScalar, DataMapper, false> rhsPackMap(blockB, depth); + rhsLS(0, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, lhsPackMap, rhsPackMap); +} + +template +struct gemm_pack_rhs +{ + void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs + ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + rhs_pack<0, 0, Index, float, DataMapper, Conjugate, PanelMode, ColMajor> pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_rhs +{ + void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs + ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + rhs_pack<0, 0, Index, float, DataMapper, Conjugate, PanelMode, RowMajor> pack; + pack(blockB, rhs, depth, cols, stride, offset); } template @@ -809,7 +908,8 @@ template ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { - + lhs_pack<0, 0, Index, float, DataMapper, Conjugate, PanelMode, RowMajor> pack; + pack(blockA, lhs, depth, rows, stride, offset); } template