From 646d92c7f1c60899fd3b8ccc64c4ac41b06a948b Mon Sep 17 00:00:00 2001 From: Everton Constantino Date: Fri, 23 Apr 2021 15:39:04 +0000 Subject: [PATCH] WIP --- Eigen/src/Core/arch/NEON/Kernels.h | 9 +- Eigen/src/Core/arch/NEON/MatrixProduct.h | 587 +---------------------- 2 files changed, 34 insertions(+), 562 deletions(-) diff --git a/Eigen/src/Core/arch/NEON/Kernels.h b/Eigen/src/Core/arch/NEON/Kernels.h index 973f71d06..3f6f25d4e 100644 --- a/Eigen/src/Core/arch/NEON/Kernels.h +++ b/Eigen/src/Core/arch/NEON/Kernels.h @@ -17,9 +17,16 @@ namespace internal { template constexpr int SHAPES_COUNT<0, CPU, LhsScalar, RhsScalar> = 7; +// lhs_progress x depth_progress x rhs_progress (depth_progress > 1 matrix ops) x pointer to next rhs_progress on the shapes map template constexpr int SHAPES<0, CPU, LhsScalar, RhsScalar>[SHAPES_COUNT<0, CPU, LhsScalar,RhsScalar>][SHAPES_DIMENSION] = - {{1,1,1,SHAPES_POINTER_END},{4,1,1,0},{1,1,4,1},{4,1,4,1},{4,4,4,1},{8,1,4,1},{8,4,4,1}}; + { {1,1,1,SHAPES_POINTER_END, SHAPES_POINTER_END, SHAPES_POINTER_END}, + {4,1,1, 0, 0, SHAPES_POINTER_END}, + {1,1,4, 1, SHAPES_POINTER_END, SHAPES_POINTER_END}, + {4,1,4, 1, 2, SHAPES_POINTER_END}, + {4,4,4, 1, 2, 3}, + {8,1,4, 1, 4, SHAPES_POINTER_END}, + {8,4,4, 1, 4, 5}}; template struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 1> diff --git a/Eigen/src/Core/arch/NEON/MatrixProduct.h b/Eigen/src/Core/arch/NEON/MatrixProduct.h index 1ff0a17bd..f766a6427 100644 --- a/Eigen/src/Core/arch/NEON/MatrixProduct.h +++ b/Eigen/src/Core/arch/NEON/MatrixProduct.h @@ -18,555 +18,16 @@ namespace Eigen { namespace internal { - -#ifdef __OLD__ -template -class PackMap -{ - const int packetSize = packet_traits::size; - const Scalar *packed_block; - const Scalar *residue_block; - Index packed_stride; - Index residue_size; - Index rows, cols; - Index offset, stride; - Scalar *cur; -public: - PackMap(const Scalar *packed_block, const Scalar *residue_block, Index rows, Index cols, Index offset, Index stride) : packed_block(packed_block), residue_block(residue_block), rows(rows), cols(cols), offset(offset), stride(stride) - { - if(IsLhs) - { - packed_stride = (rows / packetSize) * packetSize; - residue_size = rows % packetSize; - } - else { - packed_stride = (cols / packetSize) * packetSize; - residue_size = cols % packetSize; - } - }; - - PackMap(const Scalar *packed_block, Index rows, Index cols, Index offset, Index stride) : packed_block(packed_block), rows(rows), cols(cols) - { - if(IsLhs) - { - packed_stride = (rows / packetSize) * packetSize; - residue_block = packed_block + packed_stride*cols; - residue_size = rows % packetSize; - } - else { - packed_stride = (cols / packetSize) * packetSize; - residue_block = packed_block + packed_stride*rows; - residue_size = cols % packetSize; - } - - }; - - EIGEN_STRONG_INLINE Index get_packed_size() - { - return packed_stride; - }; - - EIGEN_STRONG_INLINE Index get_residue_size() - { - return residue_size; - }; - - EIGEN_STRONG_INLINE const Scalar* get_packed_at(Index at) - { - return IsLhs ? packed_block + at : packed_block + at*packetSize*rows; - }; - - EIGEN_STRONG_INLINE const Scalar* get_residue_at(Index at) - { - return residue_block + stride*at; - }; -}; - -template -EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, - Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) -{ - using AccPacket = typename packet_traits::type; - using LhsPacket = typename packet_traits::type; - using RhsPacket = typename packet_traits::type; - using ResPacket = typename packet_traits::type; - using LinearMapper = typename DataMapper::LinearMapper; - - if( strideA == -1 ) strideA = depth; - if( strideB == -1 ) strideB = depth; - - ResPacket pAlpha = pset1(alpha); - -#ifdef __DEBUG__ - std::cout << "blockA" << std::endl; - for(auto i = 0; i < rows*depth; i++) - { - if(i % strideA == 0 && i > 0) - std::cout << std::endl; - std::cout << blockA[i] << " "; - } - std::cout << std::endl; - std::cout << "blockB" << std::endl; - for(auto i = 0; i < depth*cols; i++) - { - if(i % strideB == 0 && i > 0) - std::cout << std::endl; - std::cout << blockB[i] << " "; - } - std::cout << std::endl; -#endif - - int accLhsProgress = 4; - int accRhsProgress = 4; - - PackMap lhsMap(blockA, rows, depth, offsetA, strideA); - PackMap rhsMap(blockB, depth, cols, offsetB, strideB); - auto col = 0; - for(; col + accRhsProgress <= rhsMap.get_packed_size(); col+=accRhsProgress) - { - auto row = 0; - for(; row + 3*accLhsProgress <= lhsMap.get_packed_size(); row+=3*accLhsProgress) - { - const LhsScalar *lhs_ptr1 = lhsMap.get_packed_at(row + 0*accLhsProgress); - const LhsScalar *lhs_ptr2 = lhsMap.get_packed_at(row + 1*accLhsProgress); - const LhsScalar *lhs_ptr3 = lhsMap.get_packed_at(row + 2*accLhsProgress); - const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress); - - PacketBlock acc1; - acc1.packet[0] = pset1(0); - acc1.packet[1] = pset1(0); - acc1.packet[2] = pset1(0); - acc1.packet[3] = pset1(0); - - PacketBlock acc2; - acc2.packet[0] = pset1(0); - acc2.packet[1] = pset1(0); - acc2.packet[2] = pset1(0); - acc2.packet[3] = pset1(0); - - PacketBlock acc3; - acc3.packet[0] = pset1(0); - acc3.packet[1] = pset1(0); - acc3.packet[2] = pset1(0); - acc3.packet[3] = pset1(0); - - LinearMapper r00 = res.getLinearMapper(row + 0*accLhsProgress, col + 0); - LinearMapper r01 = res.getLinearMapper(row + 0*accLhsProgress, col + 1); - LinearMapper r02 = res.getLinearMapper(row + 0*accLhsProgress, col + 2); - LinearMapper r03 = res.getLinearMapper(row + 0*accLhsProgress, col + 3); - - LinearMapper r10 = res.getLinearMapper(row + 1*accLhsProgress, col + 0); - LinearMapper r11 = res.getLinearMapper(row + 1*accLhsProgress, col + 1); - LinearMapper r12 = res.getLinearMapper(row + 1*accLhsProgress, col + 2); - LinearMapper r13 = res.getLinearMapper(row + 1*accLhsProgress, col + 3); - - LinearMapper r20 = res.getLinearMapper(row + 2*accLhsProgress, col + 0); - LinearMapper r21 = res.getLinearMapper(row + 2*accLhsProgress, col + 1); - LinearMapper r22 = res.getLinearMapper(row + 2*accLhsProgress, col + 2); - LinearMapper r23 = res.getLinearMapper(row + 2*accLhsProgress, col + 3); - - auto k = 0; - for(; k < depth; k++) - { - RhsPacket prhs = pload(rhs_ptr); - PacketBlock pbrhs; - pbrhs.packet[0] = pset1(prhs[0]); - pbrhs.packet[1] = pset1(prhs[1]); - pbrhs.packet[2] = pset1(prhs[2]); - pbrhs.packet[3] = pset1(prhs[3]); - - LhsPacket plhs1 = pload(lhs_ptr1); - LhsPacket plhs2 = pload(lhs_ptr2); - LhsPacket plhs3 = pload(lhs_ptr3); - - acc1.packet[0] += plhs1*pbrhs.packet[0]; - acc1.packet[1] += plhs1*pbrhs.packet[1]; - acc1.packet[2] += plhs1*pbrhs.packet[2]; - acc1.packet[3] += plhs1*pbrhs.packet[3]; - - acc2.packet[0] += plhs2*pbrhs.packet[0]; - acc2.packet[1] += plhs2*pbrhs.packet[1]; - acc2.packet[2] += plhs2*pbrhs.packet[2]; - acc2.packet[3] += plhs2*pbrhs.packet[3]; - - acc3.packet[0] += plhs3*pbrhs.packet[0]; - acc3.packet[1] += plhs3*pbrhs.packet[1]; - acc3.packet[2] += plhs3*pbrhs.packet[2]; - acc3.packet[3] += plhs3*pbrhs.packet[3]; - - lhs_ptr1 += (rows/accLhsProgress)*accLhsProgress; - lhs_ptr2 += (rows/accLhsProgress)*accLhsProgress; - lhs_ptr3 += (rows/accLhsProgress)*accLhsProgress; - rhs_ptr += accRhsProgress; - } - - r00.storePacket(0,r00.template loadPacket(0) + acc1.packet[0]); - r01.storePacket(0,r01.template loadPacket(0) + acc1.packet[1]); - r02.storePacket(0,r02.template loadPacket(0) + acc1.packet[2]); - r03.storePacket(0,r03.template loadPacket(0) + acc1.packet[3]); - - r10.storePacket(0,r10.template loadPacket(0) + acc2.packet[0]); - r11.storePacket(0,r11.template loadPacket(0) + acc2.packet[1]); - r12.storePacket(0,r12.template loadPacket(0) + acc2.packet[2]); - r13.storePacket(0,r13.template loadPacket(0) + acc2.packet[3]); - - r20.storePacket(0,r20.template loadPacket(0) + acc3.packet[0]); - r21.storePacket(0,r21.template loadPacket(0) + acc3.packet[1]); - r22.storePacket(0,r22.template loadPacket(0) + acc3.packet[2]); - r23.storePacket(0,r23.template loadPacket(0) + acc3.packet[3]); - } - for(; row + 2*accLhsProgress <= lhsMap.get_packed_size(); row+=2*accLhsProgress) - { - const LhsScalar *lhs_ptr1 = lhsMap.get_packed_at(row + 0*accLhsProgress); - const LhsScalar *lhs_ptr2 = lhsMap.get_packed_at(row + 1*accLhsProgress); - const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress); - - PacketBlock acc1; - acc1.packet[0] = pset1(0); - acc1.packet[1] = pset1(0); - acc1.packet[2] = pset1(0); - acc1.packet[3] = pset1(0); - - PacketBlock acc2; - acc2.packet[0] = pset1(0); - acc2.packet[1] = pset1(0); - acc2.packet[2] = pset1(0); - acc2.packet[3] = pset1(0); - - LinearMapper r00 = res.getLinearMapper(row + 0*accLhsProgress, col + 0); - LinearMapper r01 = res.getLinearMapper(row + 0*accLhsProgress, col + 1); - LinearMapper r02 = res.getLinearMapper(row + 0*accLhsProgress, col + 2); - LinearMapper r03 = res.getLinearMapper(row + 0*accLhsProgress, col + 3); - - LinearMapper r10 = res.getLinearMapper(row + 1*accLhsProgress, col + 0); - LinearMapper r11 = res.getLinearMapper(row + 1*accLhsProgress, col + 1); - LinearMapper r12 = res.getLinearMapper(row + 1*accLhsProgress, col + 2); - LinearMapper r13 = res.getLinearMapper(row + 1*accLhsProgress, col + 3); - - auto k = 0; - for(; k < depth; k++) - { - RhsPacket prhs = pload(rhs_ptr); - PacketBlock pbrhs; - pbrhs.packet[0] = pset1(prhs[0]); - pbrhs.packet[1] = pset1(prhs[1]); - pbrhs.packet[2] = pset1(prhs[2]); - pbrhs.packet[3] = pset1(prhs[3]); - - LhsPacket plhs1 = pload(lhs_ptr1); - LhsPacket plhs2 = pload(lhs_ptr2); - - acc1.packet[0] += plhs1*pbrhs.packet[0]; - acc1.packet[1] += plhs1*pbrhs.packet[1]; - acc1.packet[2] += plhs1*pbrhs.packet[2]; - acc1.packet[3] += plhs1*pbrhs.packet[3]; - - acc2.packet[0] += plhs2*pbrhs.packet[0]; - acc2.packet[1] += plhs2*pbrhs.packet[1]; - acc2.packet[2] += plhs2*pbrhs.packet[2]; - acc2.packet[3] += plhs2*pbrhs.packet[3]; - - lhs_ptr1 += (rows/accLhsProgress)*accLhsProgress; - lhs_ptr2 += (rows/accLhsProgress)*accLhsProgress; - rhs_ptr += accRhsProgress; - } - - r00.storePacket(0,r00.template loadPacket(0) + acc1.packet[0]); - r01.storePacket(0,r01.template loadPacket(0) + acc1.packet[1]); - r02.storePacket(0,r02.template loadPacket(0) + acc1.packet[2]); - r03.storePacket(0,r03.template loadPacket(0) + acc1.packet[3]); - - r10.storePacket(0,r10.template loadPacket(0) + acc2.packet[0]); - r11.storePacket(0,r11.template loadPacket(0) + acc2.packet[1]); - r12.storePacket(0,r12.template loadPacket(0) + acc2.packet[2]); - r13.storePacket(0,r13.template loadPacket(0) + acc2.packet[3]); - } - for(; row + accLhsProgress <= lhsMap.get_packed_size(); row+=accLhsProgress) - { - const LhsScalar *lhs_ptr = lhsMap.get_packed_at(row); - const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress); - PacketBlock acc; - acc.packet[0] = pset1(0); - acc.packet[1] = pset1(0); - acc.packet[2] = pset1(0); - acc.packet[3] = pset1(0); - - LinearMapper r0 = res.getLinearMapper(row, col + 0); - LinearMapper r1 = res.getLinearMapper(row, col + 1); - LinearMapper r2 = res.getLinearMapper(row, col + 2); - LinearMapper r3 = res.getLinearMapper(row, col + 3); - - auto k = 0; - for(; k < depth; k++) - { - RhsPacket prhs = pload(rhs_ptr); - PacketBlock pbrhs; - pbrhs.packet[0] = pset1(prhs[0]); - pbrhs.packet[1] = pset1(prhs[1]); - pbrhs.packet[2] = pset1(prhs[2]); - pbrhs.packet[3] = pset1(prhs[3]); - - LhsPacket plhs = pload(lhs_ptr); - -#ifdef __NDEBUG__ - std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; - std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; - std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; -#endif - acc.packet[0] += plhs*pbrhs.packet[0]; - acc.packet[1] += plhs*pbrhs.packet[1]; - acc.packet[2] += plhs*pbrhs.packet[2]; - acc.packet[3] += plhs*pbrhs.packet[3]; - - lhs_ptr += (rows/accLhsProgress)*accLhsProgress; - rhs_ptr += accRhsProgress; - } - - r0.storePacket(0,r0.template loadPacket(0) + acc.packet[0]); - r1.storePacket(0,r1.template loadPacket(0) + acc.packet[1]); - r2.storePacket(0,r2.template loadPacket(0) + acc.packet[2]); - r3.storePacket(0,r3.template loadPacket(0) + acc.packet[3]); - } - auto row_residue = 0; - for(;row < rows; row++) - { - const LhsScalar *lhs_ptr = lhsMap.get_residue_at(row_residue); - const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress); - PacketBlock acc; - acc.packet[0] = pset1(0); - - auto k = 0; - for(; k < depth; k++) - { - RhsPacket prhs = pload(rhs_ptr); - LhsPacket plhs = pset1(*lhs_ptr); - -#ifdef __NDEBUG__ - std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; - std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; - std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; -#endif - acc.packet[0] += (*lhs_ptr)*prhs; - - lhs_ptr++; - rhs_ptr += accRhsProgress; - } - - res(row, col + 0) += acc.packet[0][0]; - res(row, col + 1) += acc.packet[0][1]; - res(row, col + 2) += acc.packet[0][2]; - res(row, col + 3) += acc.packet[0][3]; - row_residue++; - } - } - auto col_residue = 0; - for(; col < cols; col++) - { - auto row = 0; - for(; row + accLhsProgress <= lhsMap.get_packed_size(); row+=accLhsProgress) - { - const LhsScalar *lhs_ptr = lhsMap.get_packed_at(row); - const RhsScalar *rhs_ptr = rhsMap.get_residue_at(col_residue); - PacketBlock acc; - acc.packet[0] = pset1(0); - - LinearMapper r0 = res.getLinearMapper(row, col + 0); - - auto k = 0; - for(; k < depth; k++) - { - RhsPacket prhs = pset1(*rhs_ptr); - - LhsPacket plhs = pload(lhs_ptr); - -#ifdef __NDEBUG__ - std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; - std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; - std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; -#endif - acc.packet[0] += plhs*prhs; - - lhs_ptr += (rows/accLhsProgress)*accLhsProgress; - rhs_ptr++; - } - - r0.storePacket(0,r0.template loadPacket(0) + acc.packet[0]); - } - auto row_residue = 0; - for(;row < rows; row++) - { - const LhsScalar *lhs_ptr = lhsMap.get_residue_at(row_residue); - const RhsScalar *rhs_ptr = rhsMap.get_residue_at(col_residue); - AccScalar acc = 0; - auto k = 0; - for(; k < depth; k++) - { -#ifdef __NDEBUG__ - std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; - std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; - std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; -#endif - acc += (*lhs_ptr)*(*rhs_ptr); - - lhs_ptr++; - rhs_ptr++; - } - - //r0.storePacket(0,r0.template loadPacket(0) + acc.packet[0]); - res(row, col) += acc; - row_residue++; - } - col_residue++; - } -} - -template -EIGEN_STRONG_INLINE void gemm_old(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, - Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) -{ - using AccPacket = typename packet_traits::type; - using LhsPacket = typename packet_traits::type; - using RhsPacket = typename packet_traits::type; - using ResPacket = typename packet_traits::type; - - ResPacket pAlpha = pset1(alpha); - -#ifdef __DEBUG__ - std::cout << "blockA" << std::endl; - for(auto i = 0; i < rows*depth; i++) - { - if(i % 4 == 0 && i > 0) - std::cout << std::endl; - std::cout << blockA[i] << " "; - } - std::cout << std::endl; - std::cout << "blockB" << std::endl; - for(auto i = 0; i < depth*cols; i++) - { - if(i % 4 == 0 && i > 0) - std::cout << std::endl; - std::cout << blockB[i] << " "; - } - std::cout << std::endl; -#endif - - if( strideA == -1 ) strideA = depth; - if( strideB == -1 ) strideB = depth; - - int accLhsProgress = 4; - int accRhsProgress = 4; - - PackMap lhsMap(blockA, rows, depth, offsetA, strideA); - PackMap rhsMap(blockB, depth, cols, offsetB, strideB); - auto col = 0; - for(; col < rhsMap.get_packed_size(); col+=accRhsProgress) - { - for(auto k = 0; k < depth; k++) - { - const LhsScalar *lhs_ptr = lhsMap.get_packed_at(k); - const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress) + k*accRhsProgress; - PacketBlock acc; - RhsPacket prhs = pload(rhs_ptr); - PacketBlock pbrhs; - pbrhs.packet[0] = pset1(prhs[0]); - pbrhs.packet[1] = pset1(prhs[1]); - pbrhs.packet[2] = pset1(prhs[2]); - pbrhs.packet[3] = pset1(prhs[3]); - auto row = 0; - using LinearMapper = typename DataMapper::LinearMapper; - for(; row < lhsMap.get_packed_size(); row+=accLhsProgress) - { - LinearMapper r0 = res.getLinearMapper(row, col + 0); - LinearMapper r1 = res.getLinearMapper(row, col + 1); - LinearMapper r2 = res.getLinearMapper(row, col + 2); - LinearMapper r3 = res.getLinearMapper(row, col + 3); - - LhsPacket plhs = pload(lhs_ptr); -#ifdef __NDEBUG__ - std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; - std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; - std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; -#endif - acc.packet[0] = plhs*pbrhs.packet[0]; - acc.packet[1] = plhs*pbrhs.packet[1]; - acc.packet[2] = plhs*pbrhs.packet[2]; - acc.packet[3] = plhs*pbrhs.packet[3]; - - r0.storePacket(0,r0.template loadPacket(0) + acc.packet[0]); - r1.storePacket(0,r1.template loadPacket(0) + acc.packet[1]); - r2.storePacket(0,r2.template loadPacket(0) + acc.packet[2]); - r3.storePacket(0,r3.template loadPacket(0) + acc.packet[3]); - lhs_ptr += accLhsProgress; - } - auto residue = 0; - for(;row < rows; row++) - { - LhsScalar lhs = *(lhsMap.get_residue_at(residue) + k); -#ifdef __NDEBUG__ - std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; - std::cout << "lhs " << lhs << " (" << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << ")" << std::endl; -#endif - res(row, col + 0) += lhs*prhs[0]; - res(row, col + 1) += lhs*prhs[1]; - res(row, col + 2) += lhs*prhs[2]; - res(row, col + 3) += lhs*prhs[3]; - residue++; - } - } - } - auto colResidue = 0; - for(;col < cols; col++) - { - for(auto k = 0; k < depth; k++) - { - const LhsScalar *lhs_ptr = lhsMap.get_packed_at(k); - const RhsScalar *rhs_ptr = rhsMap.get_residue_at(colResidue) + k; - AccPacket acc; - - RhsPacket prhs = pset1(*rhs_ptr); - - auto row = 0; - using LinearMapper = typename DataMapper::LinearMapper; - for(; row < lhsMap.get_packed_size(); row+=accLhsProgress) - { - LinearMapper r0 = res.getLinearMapper(row, col + 0); - - LhsPacket plhs = pload(lhs_ptr); -#ifdef __DEBUG__ - std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; - std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; - std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; -#endif - acc = plhs*prhs; - - r0.storePacket(0,r0.template loadPacket(0) + acc); - lhs_ptr += accLhsProgress; - } - auto residue = 0; - for(;row < rows; row++) - { - LhsScalar lhs = *(lhsMap.get_residue_at(residue) + k); -#ifdef __DEBUG__ - std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; - std::cout << "lhs " << lhs << " (" << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << ")" << std::endl; -#endif - res(row, col + 0) += lhs*prhs[0]; - residue++; - } - } - colResidue++; - } -} -#endif - template constexpr int SHAPES_COUNT = 4; -constexpr int SHAPES_DIMENSION = 4; +constexpr int SHAPES_DIMENSION = 6; constexpr int SHAPES_LHS_DIMENSION = 0; constexpr int SHAPES_DEP_DIMENSION = 1; constexpr int SHAPES_RHS_DIMENSION = 2; -constexpr int SHAPES_POINTER = 3; +constexpr int SHAPES_RHS_POINTER = 3; +constexpr int SHAPES_LHS_POINTER = 4; +constexpr int SHAPES_DEP_POINTER = 5; constexpr int SHAPES_POINTER_END = -1; template @@ -575,18 +36,18 @@ constexpr int PACK_SHAPES_DIMENSION = 3; constexpr int PACK_SHAPES_POINTER = 2; constexpr int PACK_SHAPES_END = -1; - // lhs_progress x depth_progress x rhs_progress (depth_progress > 1 matrix ops) x pointer to next rhs_progress on the shapes map template -constexpr int SHAPES[SHAPES_COUNT][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{4,1,1,0},{1,1,4,1},{4,1,4,1}}; +constexpr int SHAPES[SHAPES_COUNT][SHAPES_DIMENSION] = + { {1,1,1,SHAPES_POINTER_END, SHAPES_POINTER_END, SHAPES_POINTER_END}, + {4,1,1, 0, 0, SHAPES_POINTER_END}, + {1,1,4, 1, SHAPES_POINTER_END, SHAPES_POINTER_END}, + {4,1,4, 1, 2, SHAPES_POINTER_END}}; // d1progress x d2progress template constexpr int PACK_SHAPES[PACK_SHAPES_COUNT][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0}}; -//template -//constexpr int PACK_SHAPES[PACK_SHAPES_COUNT][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0}}; - template struct PackingOperator { @@ -816,27 +277,29 @@ struct MicroKernel template struct DepthLoopStruct { - DepthLoopStruct depthLS; + static constexpr auto PREVIOUS = SHAPES[IDX][SHAPES_DEP_POINTER]; + + DepthLoopStruct depthLS; + EIGEN_STRONG_INLINE void operator()(Index rowIdx, Index colIdx, Index depthIdx, const DataMapper& res, Index rows, Index depth, Index cols, ResScalar alpha, const ResPacket& pAlpha, LhsPackMap& lhsPackMap, RhsPackMap& rhsPackMap) { constexpr auto rhsProgress = SHAPES[RHS_SHAPE_IDX][SHAPES_RHS_DIMENSION]; constexpr auto lhsProgress = SHAPES[LHS_SHAPE_IDX][SHAPES_LHS_DIMENSION]; constexpr auto depthProgress = SHAPES[IDX][SHAPES_DEP_DIMENSION]; + typedef Accumulator AccumulatorType; - if(rhsProgress == SHAPES[IDX][SHAPES_RHS_DIMENSION] && lhsProgress == SHAPES[IDX][SHAPES_LHS_DIMENSION]) + MicroKernel mkt; + AccumulatorType acc; + acc.zero(); + for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress) { - MicroKernel mkt; - AccumulatorType acc; - acc.zero(); - for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress) - { - mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc); - } - acc.scale(alpha, pAlpha); - acc.store(res, rowIdx, colIdx); + mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc); } + acc.scale(alpha, pAlpha); + acc.store(res, rowIdx, colIdx); + depthLS(rowIdx, colIdx, depthIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); } }; @@ -851,7 +314,9 @@ struct DepthLoopStruct struct LhsLoopStruct { - LhsLoopStruct lhsLS; + static constexpr auto PREVIOUS = SHAPES[IDX][SHAPES_LHS_POINTER]; + LhsLoopStruct lhsLS; + EIGEN_STRONG_INLINE void operator()(Index rowIdx, int colIdx, const DataMapper& res, Index rows, Index depth, Index cols, ResScalar alpha, const ResPacket& pAlpha, LhsPackMap& lhsPackMap, RhsPackMap& rhsPackMap) { @@ -878,7 +343,7 @@ struct LhsLoopStruct struct RhsLoopStruct { - static constexpr auto PREVIOUS = SHAPES[IDX][SHAPES_POINTER]; + static constexpr auto PREVIOUS = SHAPES[IDX][SHAPES_RHS_POINTER]; RhsLoopStruct rhsLS; EIGEN_STRONG_INLINE void operator()(Index colIdx, const DataMapper& res,