From f826663a3ac4b1bd71f3562732fdde9fe3b48254 Mon Sep 17 00:00:00 2001 From: Everton Constantino Date: Tue, 20 Apr 2021 20:10:21 +0000 Subject: [PATCH] WIP --- Eigen/Core | 1 + Eigen/src/Core/arch/NEON/MatrixProduct.h | 848 +++++++++++++++++++++++ 2 files changed, 849 insertions(+) create mode 100644 Eigen/src/Core/arch/NEON/MatrixProduct.h diff --git a/Eigen/Core b/Eigen/Core index 5921e15f9..5545fd011 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -350,6 +350,7 @@ using std::ptrdiff_t; #include "src/Core/arch/AltiVec/MatrixProduct.h" #elif defined EIGEN_VECTORIZE_NEON #include "src/Core/arch/NEON/GeneralBlockPanelKernel.h" + #include "src/Core/arch/NEON/MatrixProduct.h" #endif #include "src/Core/BooleanRedux.h" diff --git a/Eigen/src/Core/arch/NEON/MatrixProduct.h b/Eigen/src/Core/arch/NEON/MatrixProduct.h new file mode 100644 index 000000000..14370ede3 --- /dev/null +++ b/Eigen/src/Core/arch/NEON/MatrixProduct.h @@ -0,0 +1,848 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2021 Everton Constantino (everton.constantino@hotmail.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_MATRIX_PRODUCT_NEON_H +#define EIGEN_MATRIX_PRODUCT_NEON_H + +#ifdef __DEBUG__ +#include +#endif + +namespace Eigen { + +namespace internal { + + +#ifdef __OLD__ +template +class PackMap +{ + const int packetSize = packet_traits::size; + const Scalar *packed_block; + const Scalar *residue_block; + Index packed_stride; + Index residue_size; + Index rows, cols; + Index offset, stride; + Scalar *cur; +public: + PackMap(const Scalar *packed_block, const Scalar *residue_block, Index rows, Index cols, Index offset, Index stride) : packed_block(packed_block), residue_block(residue_block), rows(rows), cols(cols), offset(offset), stride(stride) + { + if(IsLhs) + { + packed_stride = (rows / packetSize) * packetSize; + residue_size = rows % packetSize; + } + else { + packed_stride = (cols / packetSize) * packetSize; + residue_size = cols % packetSize; + } + }; + + PackMap(const Scalar *packed_block, Index rows, Index cols, Index offset, Index stride) : packed_block(packed_block), rows(rows), cols(cols) + { + if(IsLhs) + { + packed_stride = (rows / packetSize) * packetSize; + residue_block = packed_block + packed_stride*cols; + residue_size = rows % packetSize; + } + else { + packed_stride = (cols / packetSize) * packetSize; + residue_block = packed_block + packed_stride*rows; + residue_size = cols % packetSize; + } + + }; + + EIGEN_STRONG_INLINE Index get_packed_size() + { + return packed_stride; + }; + + EIGEN_STRONG_INLINE Index get_residue_size() + { + return residue_size; + }; + + EIGEN_STRONG_INLINE const Scalar* get_packed_at(Index at) + { + return IsLhs ? packed_block + at : packed_block + at*packetSize*rows; + }; + + EIGEN_STRONG_INLINE const Scalar* get_residue_at(Index at) + { + return residue_block + stride*at; + }; +}; + +template +EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + using AccPacket = typename packet_traits::type; + using LhsPacket = typename packet_traits::type; + using RhsPacket = typename packet_traits::type; + using ResPacket = typename packet_traits::type; + using LinearMapper = typename DataMapper::LinearMapper; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + ResPacket pAlpha = pset1(alpha); + +#ifdef __DEBUG__ + std::cout << "blockA" << std::endl; + for(auto i = 0; i < rows*depth; i++) + { + if(i % strideA == 0 && i > 0) + std::cout << std::endl; + std::cout << blockA[i] << " "; + } + std::cout << std::endl; + std::cout << "blockB" << std::endl; + for(auto i = 0; i < depth*cols; i++) + { + if(i % strideB == 0 && i > 0) + std::cout << std::endl; + std::cout << blockB[i] << " "; + } + std::cout << std::endl; +#endif + + int accLhsProgress = 4; + int accRhsProgress = 4; + + PackMap lhsMap(blockA, rows, depth, offsetA, strideA); + PackMap rhsMap(blockB, depth, cols, offsetB, strideB); + auto col = 0; + for(; col + accRhsProgress <= rhsMap.get_packed_size(); col+=accRhsProgress) + { + auto row = 0; + for(; row + 3*accLhsProgress <= lhsMap.get_packed_size(); row+=3*accLhsProgress) + { + const LhsScalar *lhs_ptr1 = lhsMap.get_packed_at(row + 0*accLhsProgress); + const LhsScalar *lhs_ptr2 = lhsMap.get_packed_at(row + 1*accLhsProgress); + const LhsScalar *lhs_ptr3 = lhsMap.get_packed_at(row + 2*accLhsProgress); + const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress); + + PacketBlock acc1; + acc1.packet[0] = pset1(0); + acc1.packet[1] = pset1(0); + acc1.packet[2] = pset1(0); + acc1.packet[3] = pset1(0); + + PacketBlock acc2; + acc2.packet[0] = pset1(0); + acc2.packet[1] = pset1(0); + acc2.packet[2] = pset1(0); + acc2.packet[3] = pset1(0); + + PacketBlock acc3; + acc3.packet[0] = pset1(0); + acc3.packet[1] = pset1(0); + acc3.packet[2] = pset1(0); + acc3.packet[3] = pset1(0); + + LinearMapper r00 = res.getLinearMapper(row + 0*accLhsProgress, col + 0); + LinearMapper r01 = res.getLinearMapper(row + 0*accLhsProgress, col + 1); + LinearMapper r02 = res.getLinearMapper(row + 0*accLhsProgress, col + 2); + LinearMapper r03 = res.getLinearMapper(row + 0*accLhsProgress, col + 3); + + LinearMapper r10 = res.getLinearMapper(row + 1*accLhsProgress, col + 0); + LinearMapper r11 = res.getLinearMapper(row + 1*accLhsProgress, col + 1); + LinearMapper r12 = res.getLinearMapper(row + 1*accLhsProgress, col + 2); + LinearMapper r13 = res.getLinearMapper(row + 1*accLhsProgress, col + 3); + + LinearMapper r20 = res.getLinearMapper(row + 2*accLhsProgress, col + 0); + LinearMapper r21 = res.getLinearMapper(row + 2*accLhsProgress, col + 1); + LinearMapper r22 = res.getLinearMapper(row + 2*accLhsProgress, col + 2); + LinearMapper r23 = res.getLinearMapper(row + 2*accLhsProgress, col + 3); + + auto k = 0; + for(; k < depth; k++) + { + RhsPacket prhs = pload(rhs_ptr); + PacketBlock pbrhs; + pbrhs.packet[0] = pset1(prhs[0]); + pbrhs.packet[1] = pset1(prhs[1]); + pbrhs.packet[2] = pset1(prhs[2]); + pbrhs.packet[3] = pset1(prhs[3]); + + LhsPacket plhs1 = pload(lhs_ptr1); + LhsPacket plhs2 = pload(lhs_ptr2); + LhsPacket plhs3 = pload(lhs_ptr3); + + acc1.packet[0] += plhs1*pbrhs.packet[0]; + acc1.packet[1] += plhs1*pbrhs.packet[1]; + acc1.packet[2] += plhs1*pbrhs.packet[2]; + acc1.packet[3] += plhs1*pbrhs.packet[3]; + + acc2.packet[0] += plhs2*pbrhs.packet[0]; + acc2.packet[1] += plhs2*pbrhs.packet[1]; + acc2.packet[2] += plhs2*pbrhs.packet[2]; + acc2.packet[3] += plhs2*pbrhs.packet[3]; + + acc3.packet[0] += plhs3*pbrhs.packet[0]; + acc3.packet[1] += plhs3*pbrhs.packet[1]; + acc3.packet[2] += plhs3*pbrhs.packet[2]; + acc3.packet[3] += plhs3*pbrhs.packet[3]; + + lhs_ptr1 += (rows/accLhsProgress)*accLhsProgress; + lhs_ptr2 += (rows/accLhsProgress)*accLhsProgress; + lhs_ptr3 += (rows/accLhsProgress)*accLhsProgress; + rhs_ptr += accRhsProgress; + } + + r00.storePacket(0,r00.template loadPacket(0) + acc1.packet[0]); + r01.storePacket(0,r01.template loadPacket(0) + acc1.packet[1]); + r02.storePacket(0,r02.template loadPacket(0) + acc1.packet[2]); + r03.storePacket(0,r03.template loadPacket(0) + acc1.packet[3]); + + r10.storePacket(0,r10.template loadPacket(0) + acc2.packet[0]); + r11.storePacket(0,r11.template loadPacket(0) + acc2.packet[1]); + r12.storePacket(0,r12.template loadPacket(0) + acc2.packet[2]); + r13.storePacket(0,r13.template loadPacket(0) + acc2.packet[3]); + + r20.storePacket(0,r20.template loadPacket(0) + acc3.packet[0]); + r21.storePacket(0,r21.template loadPacket(0) + acc3.packet[1]); + r22.storePacket(0,r22.template loadPacket(0) + acc3.packet[2]); + r23.storePacket(0,r23.template loadPacket(0) + acc3.packet[3]); + } + for(; row + 2*accLhsProgress <= lhsMap.get_packed_size(); row+=2*accLhsProgress) + { + const LhsScalar *lhs_ptr1 = lhsMap.get_packed_at(row + 0*accLhsProgress); + const LhsScalar *lhs_ptr2 = lhsMap.get_packed_at(row + 1*accLhsProgress); + const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress); + + PacketBlock acc1; + acc1.packet[0] = pset1(0); + acc1.packet[1] = pset1(0); + acc1.packet[2] = pset1(0); + acc1.packet[3] = pset1(0); + + PacketBlock acc2; + acc2.packet[0] = pset1(0); + acc2.packet[1] = pset1(0); + acc2.packet[2] = pset1(0); + acc2.packet[3] = pset1(0); + + LinearMapper r00 = res.getLinearMapper(row + 0*accLhsProgress, col + 0); + LinearMapper r01 = res.getLinearMapper(row + 0*accLhsProgress, col + 1); + LinearMapper r02 = res.getLinearMapper(row + 0*accLhsProgress, col + 2); + LinearMapper r03 = res.getLinearMapper(row + 0*accLhsProgress, col + 3); + + LinearMapper r10 = res.getLinearMapper(row + 1*accLhsProgress, col + 0); + LinearMapper r11 = res.getLinearMapper(row + 1*accLhsProgress, col + 1); + LinearMapper r12 = res.getLinearMapper(row + 1*accLhsProgress, col + 2); + LinearMapper r13 = res.getLinearMapper(row + 1*accLhsProgress, col + 3); + + auto k = 0; + for(; k < depth; k++) + { + RhsPacket prhs = pload(rhs_ptr); + PacketBlock pbrhs; + pbrhs.packet[0] = pset1(prhs[0]); + pbrhs.packet[1] = pset1(prhs[1]); + pbrhs.packet[2] = pset1(prhs[2]); + pbrhs.packet[3] = pset1(prhs[3]); + + LhsPacket plhs1 = pload(lhs_ptr1); + LhsPacket plhs2 = pload(lhs_ptr2); + + acc1.packet[0] += plhs1*pbrhs.packet[0]; + acc1.packet[1] += plhs1*pbrhs.packet[1]; + acc1.packet[2] += plhs1*pbrhs.packet[2]; + acc1.packet[3] += plhs1*pbrhs.packet[3]; + + acc2.packet[0] += plhs2*pbrhs.packet[0]; + acc2.packet[1] += plhs2*pbrhs.packet[1]; + acc2.packet[2] += plhs2*pbrhs.packet[2]; + acc2.packet[3] += plhs2*pbrhs.packet[3]; + + lhs_ptr1 += (rows/accLhsProgress)*accLhsProgress; + lhs_ptr2 += (rows/accLhsProgress)*accLhsProgress; + rhs_ptr += accRhsProgress; + } + + r00.storePacket(0,r00.template loadPacket(0) + acc1.packet[0]); + r01.storePacket(0,r01.template loadPacket(0) + acc1.packet[1]); + r02.storePacket(0,r02.template loadPacket(0) + acc1.packet[2]); + r03.storePacket(0,r03.template loadPacket(0) + acc1.packet[3]); + + r10.storePacket(0,r10.template loadPacket(0) + acc2.packet[0]); + r11.storePacket(0,r11.template loadPacket(0) + acc2.packet[1]); + r12.storePacket(0,r12.template loadPacket(0) + acc2.packet[2]); + r13.storePacket(0,r13.template loadPacket(0) + acc2.packet[3]); + } + for(; row + accLhsProgress <= lhsMap.get_packed_size(); row+=accLhsProgress) + { + const LhsScalar *lhs_ptr = lhsMap.get_packed_at(row); + const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress); + PacketBlock acc; + acc.packet[0] = pset1(0); + acc.packet[1] = pset1(0); + acc.packet[2] = pset1(0); + acc.packet[3] = pset1(0); + + LinearMapper r0 = res.getLinearMapper(row, col + 0); + LinearMapper r1 = res.getLinearMapper(row, col + 1); + LinearMapper r2 = res.getLinearMapper(row, col + 2); + LinearMapper r3 = res.getLinearMapper(row, col + 3); + + auto k = 0; + for(; k < depth; k++) + { + RhsPacket prhs = pload(rhs_ptr); + PacketBlock pbrhs; + pbrhs.packet[0] = pset1(prhs[0]); + pbrhs.packet[1] = pset1(prhs[1]); + pbrhs.packet[2] = pset1(prhs[2]); + pbrhs.packet[3] = pset1(prhs[3]); + + LhsPacket plhs = pload(lhs_ptr); + +#ifdef __NDEBUG__ + std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; + std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; + std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; +#endif + acc.packet[0] += plhs*pbrhs.packet[0]; + acc.packet[1] += plhs*pbrhs.packet[1]; + acc.packet[2] += plhs*pbrhs.packet[2]; + acc.packet[3] += plhs*pbrhs.packet[3]; + + lhs_ptr += (rows/accLhsProgress)*accLhsProgress; + rhs_ptr += accRhsProgress; + } + + r0.storePacket(0,r0.template loadPacket(0) + acc.packet[0]); + r1.storePacket(0,r1.template loadPacket(0) + acc.packet[1]); + r2.storePacket(0,r2.template loadPacket(0) + acc.packet[2]); + r3.storePacket(0,r3.template loadPacket(0) + acc.packet[3]); + } + auto row_residue = 0; + for(;row < rows; row++) + { + const LhsScalar *lhs_ptr = lhsMap.get_residue_at(row_residue); + const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress); + PacketBlock acc; + acc.packet[0] = pset1(0); + + auto k = 0; + for(; k < depth; k++) + { + RhsPacket prhs = pload(rhs_ptr); + LhsPacket plhs = pset1(*lhs_ptr); + +#ifdef __NDEBUG__ + std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; + std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; + std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; +#endif + acc.packet[0] += (*lhs_ptr)*prhs; + + lhs_ptr++; + rhs_ptr += accRhsProgress; + } + + res(row, col + 0) += acc.packet[0][0]; + res(row, col + 1) += acc.packet[0][1]; + res(row, col + 2) += acc.packet[0][2]; + res(row, col + 3) += acc.packet[0][3]; + row_residue++; + } + } + auto col_residue = 0; + for(; col < cols; col++) + { + auto row = 0; + for(; row + accLhsProgress <= lhsMap.get_packed_size(); row+=accLhsProgress) + { + const LhsScalar *lhs_ptr = lhsMap.get_packed_at(row); + const RhsScalar *rhs_ptr = rhsMap.get_residue_at(col_residue); + PacketBlock acc; + acc.packet[0] = pset1(0); + + LinearMapper r0 = res.getLinearMapper(row, col + 0); + + auto k = 0; + for(; k < depth; k++) + { + RhsPacket prhs = pset1(*rhs_ptr); + + LhsPacket plhs = pload(lhs_ptr); + +#ifdef __NDEBUG__ + std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; + std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; + std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; +#endif + acc.packet[0] += plhs*prhs; + + lhs_ptr += (rows/accLhsProgress)*accLhsProgress; + rhs_ptr++; + } + + r0.storePacket(0,r0.template loadPacket(0) + acc.packet[0]); + } + auto row_residue = 0; + for(;row < rows; row++) + { + const LhsScalar *lhs_ptr = lhsMap.get_residue_at(row_residue); + const RhsScalar *rhs_ptr = rhsMap.get_residue_at(col_residue); + AccScalar acc = 0; + auto k = 0; + for(; k < depth; k++) + { +#ifdef __NDEBUG__ + std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; + std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; + std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; +#endif + acc += (*lhs_ptr)*(*rhs_ptr); + + lhs_ptr++; + rhs_ptr++; + } + + //r0.storePacket(0,r0.template loadPacket(0) + acc.packet[0]); + res(row, col) += acc; + row_residue++; + } + col_residue++; + } +} + +template +EIGEN_STRONG_INLINE void gemm_old(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + using AccPacket = typename packet_traits::type; + using LhsPacket = typename packet_traits::type; + using RhsPacket = typename packet_traits::type; + using ResPacket = typename packet_traits::type; + + ResPacket pAlpha = pset1(alpha); + +#ifdef __DEBUG__ + std::cout << "blockA" << std::endl; + for(auto i = 0; i < rows*depth; i++) + { + if(i % 4 == 0 && i > 0) + std::cout << std::endl; + std::cout << blockA[i] << " "; + } + std::cout << std::endl; + std::cout << "blockB" << std::endl; + for(auto i = 0; i < depth*cols; i++) + { + if(i % 4 == 0 && i > 0) + std::cout << std::endl; + std::cout << blockB[i] << " "; + } + std::cout << std::endl; +#endif + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + int accLhsProgress = 4; + int accRhsProgress = 4; + + PackMap lhsMap(blockA, rows, depth, offsetA, strideA); + PackMap rhsMap(blockB, depth, cols, offsetB, strideB); + auto col = 0; + for(; col < rhsMap.get_packed_size(); col+=accRhsProgress) + { + for(auto k = 0; k < depth; k++) + { + const LhsScalar *lhs_ptr = lhsMap.get_packed_at(k); + const RhsScalar *rhs_ptr = rhsMap.get_packed_at(col/accRhsProgress) + k*accRhsProgress; + PacketBlock acc; + RhsPacket prhs = pload(rhs_ptr); + PacketBlock pbrhs; + pbrhs.packet[0] = pset1(prhs[0]); + pbrhs.packet[1] = pset1(prhs[1]); + pbrhs.packet[2] = pset1(prhs[2]); + pbrhs.packet[3] = pset1(prhs[3]); + auto row = 0; + using LinearMapper = typename DataMapper::LinearMapper; + for(; row < lhsMap.get_packed_size(); row+=accLhsProgress) + { + LinearMapper r0 = res.getLinearMapper(row, col + 0); + LinearMapper r1 = res.getLinearMapper(row, col + 1); + LinearMapper r2 = res.getLinearMapper(row, col + 2); + LinearMapper r3 = res.getLinearMapper(row, col + 3); + + LhsPacket plhs = pload(lhs_ptr); +#ifdef __NDEBUG__ + std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; + std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; + std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; +#endif + acc.packet[0] = plhs*pbrhs.packet[0]; + acc.packet[1] = plhs*pbrhs.packet[1]; + acc.packet[2] = plhs*pbrhs.packet[2]; + acc.packet[3] = plhs*pbrhs.packet[3]; + + r0.storePacket(0,r0.template loadPacket(0) + acc.packet[0]); + r1.storePacket(0,r1.template loadPacket(0) + acc.packet[1]); + r2.storePacket(0,r2.template loadPacket(0) + acc.packet[2]); + r3.storePacket(0,r3.template loadPacket(0) + acc.packet[3]); + lhs_ptr += accLhsProgress; + } + auto residue = 0; + for(;row < rows; row++) + { + LhsScalar lhs = *(lhsMap.get_residue_at(residue) + k); +#ifdef __NDEBUG__ + std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; + std::cout << "lhs " << lhs << " (" << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << ")" << std::endl; +#endif + res(row, col + 0) += lhs*prhs[0]; + res(row, col + 1) += lhs*prhs[1]; + res(row, col + 2) += lhs*prhs[2]; + res(row, col + 3) += lhs*prhs[3]; + residue++; + } + } + } + auto colResidue = 0; + for(;col < cols; col++) + { + for(auto k = 0; k < depth; k++) + { + const LhsScalar *lhs_ptr = lhsMap.get_packed_at(k); + const RhsScalar *rhs_ptr = rhsMap.get_residue_at(colResidue) + k; + AccPacket acc; + + RhsPacket prhs = pset1(*rhs_ptr); + + auto row = 0; + using LinearMapper = typename DataMapper::LinearMapper; + for(; row < lhsMap.get_packed_size(); row+=accLhsProgress) + { + LinearMapper r0 = res.getLinearMapper(row, col + 0); + + LhsPacket plhs = pload(lhs_ptr); +#ifdef __DEBUG__ + std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; + std::cout << "lhs " << plhs[0] << " " << plhs[1] << " " << plhs[2] << " " << plhs[3] << std::endl; + std::cout << "rhs " << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << std::endl; +#endif + acc = plhs*prhs; + + r0.storePacket(0,r0.template loadPacket(0) + acc); + lhs_ptr += accLhsProgress; + } + auto residue = 0; + for(;row < rows; row++) + { + LhsScalar lhs = *(lhsMap.get_residue_at(residue) + k); +#ifdef __DEBUG__ + std::cout << "(" << row << "," << k << "," << col << ")" << std::endl; + std::cout << "lhs " << lhs << " (" << prhs[0] << " " << prhs[1] << " " << prhs[2] << " " << prhs[3] << ")" << std::endl; +#endif + res(row, col + 0) += lhs*prhs[0]; + residue++; + } + } + colResidue++; + } +} +#endif + +template +constexpr int SHAPES_COUNT = 3; + +constexpr int SHAPES_DIMENSION = 4; +constexpr int SHAPES_LHS_DIMENSION = 0; +constexpr int SHAPES_DEP_DIMENSION = 1; +constexpr int SHAPES_RHS_DIMENSION = 2; +constexpr int SHAPES_POINTER = 3; +constexpr int SHAPES_POINTER_END = -1; + +template +constexpr int PACK_SHAPES_COUNT = 3; +constexpr int PACK_SHAPES_DIMENSION = 3; +constexpr int PACK_SHAPES_POINTER = 2; +constexpr int PACK_SHAPES_END = -1; + + +// lhs_progress x depth_progress x rhs_progress (depth_progress > 1 matrix ops) x pointer to next rhs_progress on the shapes map +template +constexpr int SHAPES[SHAPES_COUNT][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{4,1,4,0},{8,1,8,1}}; +//constexpr int SHAPES[SHAPES_COUNT][SHAPES_DIMENSION] = {{1,1,1,SHAPES_POINTER_END},{2,1,1,0},{1,1,2,1},{2,1,2,1},{2,2,2,1}}; + +// d1progress x d2progress +template +constexpr int PACK_SHAPES[PACK_SHAPES_COUNT][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0},{8,1,1}}; + +template +constexpr int PACK_SHAPES[PACK_SHAPES_COUNT][PACK_SHAPES_DIMENSION] = {{1,1,PACK_SHAPES_END},{4,1,0},{8,1,1}}; + +template +struct PackingOperator +{ + EIGEN_STRONG_INLINE void operator()(Index d1Idx, Index d2Idx, Scalar **block, const DataMapper& data) + { + std::cout << M << "x" << N << " ( " << d1Idx << ", " << d2Idx <<") -> ( " << d1Idx + M << ", " << d2Idx + N << ")" << std::endl; + Scalar *c = *block; + for(auto i = 0; i < M; i++) + for(auto j = 0; j < N; j++) + { + *c = data(d1Idx + i, d2Idx + j); + c++; + } + + *block = c; + } +}; + +template +struct PackingInnerStruct +{ + EIGEN_STRONG_INLINE void operator()(Index d1Idx, Index d2Idx, Scalar *block, const DataMapper& data, Index d1Size, Index d2Size, Index stride, Index offset) + { + constexpr auto d2Progress = PACK_SHAPES[IDX][1]; + PackingOperator po; + + for(;d2Idx + d2Progress <= d2Size; d2Idx+=d2Progress) + { + po(d1Idx, d2Idx, &block, data); + } + + if(PACK_SHAPES[IDX-1][0] == D1PROGRESS) + { + PackingInnerStruct pis; + pis(d1Idx, d2Idx, block, data, d1Size, d2Size, stride, offset); + } + } +}; + +template +struct PackingInnerStruct +{ + EIGEN_STRONG_INLINE void operator()(Index d1Idx, Index d2Idx, Scalar *block, const DataMapper& data, Index d1Size, Index d2Size, Index stride, Index offset) + { + constexpr auto d2Progress = PACK_SHAPES[0][1]; + for(;d2Idx < d2Size; d2Idx++) + { + PackingOperator po; + po(d1Idx, d2Idx, &block, data); + } + } +}; + +template +struct PackingStruct +{ + PackingStruct[PACK_SHAPE_IDX][PACK_SHAPES_POINTER]> ps; + + EIGEN_STRONG_INLINE void operator()(Index d1Idx, Scalar *block, const DataMapper& data, Index d1Size, Index d2Size, Index stride, Index offset) + { + constexpr auto d1Progress = PACK_SHAPES[PACK_SHAPE_IDX][0]; + + for(; d1Idx + d1Progress <= d1Size; d1Idx += d1Progress) + { + PackingInnerStruct pis; + pis(d1Idx, 0, block, data, d1Size, d2Size, stride, offset); + } + ps(d1Idx, block, data, d1Size, d2Size, stride, offset); + } +}; + +template +struct PackingStruct +{ + EIGEN_STRONG_INLINE void operator()(Index, Scalar *, const DataMapper&, Index, Index, Index, Index) {} +}; + +template +struct lhs_pack +{ + EIGEN_STRONG_INLINE void operator()(Scalar *blockA, const DataMapper &lhs, Index depth, Index rows, Index stride, Index offset) + { + PackingStruct-1> ps; + ps(0, blockA, lhs, rows, depth, stride, offset); + } +}; + +template +struct MicroKernel +{ + EIGEN_STRONG_INLINE void operator()(const LhsScalar** ppLhs,const RhsScalar** ppRhs, Index rowIdx, Index colIdx, Index depthIdx) + { + const LhsScalar *pLhs = *ppLhs; + const RhsScalar *pRhs = *ppRhs; + + std::cout << "Kernel " << M << " x " << K << " x " << N << " @ " << rowIdx << ", " << depthIdx << ", " << colIdx << std::endl; + std::cout << "LHS "; + for(auto i = rowIdx; i < M + rowIdx; i++) + { + for(auto j = depthIdx; j < K + depthIdx; j++) + { + std::cout << *pLhs << " "; + pLhs++; + } + } + std::cout << std::endl << "RHS "; + for(auto i = depthIdx; i < K + depthIdx; i++) + { + for(auto j = colIdx; j < N + colIdx; j++) + { + std::cout << *pRhs << " "; + pRhs++; + } + } + std::cout << std::endl; + *ppLhs += M*K; + *ppRhs += N*K; + }; +}; + +template +struct DepthLoopStruct +{ + DepthLoopStruct depthLS; + EIGEN_STRONG_INLINE void operator()(Index rowIdx, Index colIdx, Index depthIdx, const DataMapper& res, const LhsScalar* blockA, const RhsScalar*blockB, + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) + { + constexpr auto rhsProgress = SHAPES[RHS_SHAPE_IDX][SHAPES_RHS_DIMENSION]; + constexpr auto lhsProgress = SHAPES[LHS_SHAPE_IDX][SHAPES_LHS_DIMENSION]; + constexpr auto depthProgress = SHAPES[IDX][SHAPES_DEP_DIMENSION]; + + if(rhsProgress == SHAPES[IDX][SHAPES_RHS_DIMENSION] && lhsProgress == SHAPES[IDX][SHAPES_LHS_DIMENSION]) + { + MicroKernel mkt; + for(; depthIdx + depthProgress <= depth; depthIdx+=depthProgress) + { + mkt(&blockA, &blockB, rowIdx, colIdx, depthIdx); + } + } + depthLS(rowIdx, colIdx, depthIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } +}; + +template +struct DepthLoopStruct +{ + EIGEN_STRONG_INLINE void operator()(Index, Index, Index, const DataMapper&, const LhsScalar*, const RhsScalar*, + Index, Index, Index, ResScalar, Index, Index, Index, Index) {} +}; + +template +struct LhsLoopStruct +{ + LhsLoopStruct lhsLS; + EIGEN_STRONG_INLINE void operator()(Index rowIdx, int colIdx, const DataMapper& res, const LhsScalar* blockA, const RhsScalar*blockB, + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) + { + constexpr auto lhsProgress = SHAPES[IDX][SHAPES_LHS_DIMENSION]; + + DepthLoopStruct depthLS; + for(;rowIdx + lhsProgress <= rows; rowIdx+=lhsProgress) + { + depthLS(rowIdx, colIdx, 0, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } + lhsLS(rowIdx, colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } +}; + +template +struct LhsLoopStruct +{ + EIGEN_STRONG_INLINE void operator()(Index, Index, const DataMapper&, const LhsScalar*, const RhsScalar*, + Index, Index, Index, ResScalar, Index, Index, Index, Index) {} +}; + +template +struct RhsLoopStruct +{ + static constexpr auto PREVIOUS = SHAPES[IDX][SHAPES_POINTER]; + RhsLoopStruct rhsLS; + + EIGEN_STRONG_INLINE void operator()(Index colIdx, const DataMapper& res, const LhsScalar* blockA, const RhsScalar*blockB, + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) + { + constexpr auto rhsProgress = SHAPES[IDX][SHAPES_RHS_DIMENSION]; + + std::cout << __PRETTY_FUNCTION__ << std::endl; + for(;colIdx + rhsProgress <= cols; colIdx+=rhsProgress) + { + LhsLoopStruct lhsLS; + lhsLS(0, colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } + rhsLS(colIdx, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } +}; + +template +struct RhsLoopStruct +{ + EIGEN_STRONG_INLINE void operator()(Index colIdx, const DataMapper&, const LhsScalar*, const RhsScalar*, + Index, Index, Index, ResScalar, Index, Index, Index, Index) {} +}; + +template +EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const LhsScalar* blockA, const RhsScalar* blockB, + Index rows, Index depth, Index cols, ResScalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + RhsLoopStruct<0, 0, Index, LhsScalar, RhsScalar, AccScalar, ResScalar, DataMapper, SHAPES_COUNT<0, 0, LhsScalar, RhsScalar>-1> rhsLS; + rhsLS(0, res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); +} + +template +struct gemm_pack_lhs +{ + void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + +} + +template +struct gemm_pack_lhs +{ + void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + lhs_pack<0, 0, Index, float, DataMapper, Conjugate, PanelMode, ColMajor> pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gebp_kernel +{ + void operator()(const DataMapper& res, const float* blockA, const float* blockB, + Index rows, Index depth, Index cols, float alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel + ::operator()(const DataMapper& res, const float* blockA, const float* blockB, + Index rows, Index depth, Index cols, float alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + gemm(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } +} // end namespace internal + +} // end namespace Eigen +#endif // EIGEN_MATRIX_PRODUCT_NEON_H \ No newline at end of file