From 9b51dc7972c9f64727e9c8e8db0c60aaf9aae532 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Wed, 17 Feb 2021 17:49:23 +0000 Subject: [PATCH] Fixed performance issues for VSX and P10 MMA in general_matrix_matrix_product --- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 1330 +++++++++-------- .../src/Core/arch/AltiVec/MatrixProductMMA.h | 389 +++-- Eigen/src/Core/arch/AltiVec/PacketMath.h | 11 +- 3 files changed, 986 insertions(+), 744 deletions(-) diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 53116ad89..9d9bbebe5 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -70,7 +70,7 @@ struct quad_traits // MatrixProduct decomposes real/imaginary vectors into a real vector and an imaginary vector, this turned out // to be faster than Eigen's usual approach of having real/imaginary pairs on a single vector. This constants then // are responsible to extract from convert between Eigen's and MatrixProduct approach. -const static Packet4f p4f_CONJUGATE = {-1.0f, -1.0f, -1.0f, -1.0f}; +const static Packet4f p4f_CONJUGATE = {float(-1.0), float(-1.0), float(-1.0), float(-1.0)}; const static Packet2d p2d_CONJUGATE = {-1.0, -1.0}; @@ -122,7 +122,7 @@ EIGEN_STRONG_INLINE std::complex getAdjointVal(Index i, Index j, const_b v.imag(dt(i,j).imag()); } else { v.real(dt(i,j).real()); - v.imag((Scalar)0.0f); + v.imag((Scalar)0.0); } return v; } @@ -136,7 +136,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex *bloc Scalar* blockBf = reinterpret_cast(blockB); Index ri = 0, j = 0; - for(; j + vectorSize < cols; j+=vectorSize) + for(; j + vectorSize <= cols; j+=vectorSize) { Index i = k2; for(; i < depth; i++) @@ -192,7 +192,7 @@ EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex *bloc Index ri = 0, j = 0; Scalar *blockAf = (Scalar *)(blockA); - for(; j + vectorSize < rows; j+=vectorSize) + for(; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; @@ -247,7 +247,7 @@ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar *blockB, const Scalar* _rhs const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(; j + N*vectorSize < cols; j+=N*vectorSize) + for(; j + N*vectorSize <= cols; j+=N*vectorSize) { Index i = k2; for(; i < depth; i++) @@ -284,7 +284,7 @@ EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar *blockA, const Scalar* _lhs const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; @@ -410,15 +410,15 @@ struct lhs_cpack { const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; Scalar *blockAt = reinterpret_cast(blockA); - Packet conj = pset1((Scalar)-1.0f); + Packet conj = pset1((Scalar)-1.0); - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; @@ -446,10 +446,10 @@ struct lhs_cpack { cblock.packet[7] = lhs.template loadPacket(j + 3, i + 2); } - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32); - block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32); - block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, 
p16uc_GETREAL32); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32); + block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32); + block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32); if(StorageOrder == RowMajor) ptranspose(block); @@ -475,7 +475,7 @@ struct lhs_cpack { if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock cblock; if(StorageOrder == ColMajor) @@ -502,10 +502,10 @@ struct lhs_cpack { } PacketBlock block; - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32); - block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32); - block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32); + block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32); + block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32); if(Conjugate) { @@ -585,13 +585,13 @@ struct lhs_pack{ const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; @@ -637,13 +637,16 @@ struct lhs_pack{ if(PanelMode) ri += offset*(rows - j); - for(Index i = 0; i < depth; i++) + if (j < rows) { - Index k = j; - for(; k < rows; k++) + for(Index i = 0; i < depth; i++) { - blockA[ri] = lhs(k, i); - ri += 1; + Index k = j; + for(; k < rows; k++) + { + blockA[ri] = lhs(k, i); + ri += 1; + } } } @@ -659,16 +662,16 @@ struct rhs_cpack { const int vectorSize = quad_traits::vectorsize; Scalar *blockBt = reinterpret_cast(blockB); - Packet conj = pset1((Scalar)-1.0f); + Packet conj = pset1((Scalar)-1.0); Index ri = 0, j = 0; - for(; j + vectorSize < cols; j+=vectorSize) + for(; j + vectorSize <= cols; j+=vectorSize) { Index i = 0; if(PanelMode) ri += offset*vectorSize; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock cblock; if(StorageOrder == ColMajor) @@ -695,10 +698,10 @@ struct rhs_cpack } PacketBlock block; - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETREAL32); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETREAL32); - block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETREAL32); - block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETREAL32); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32); + block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32); + block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32); if(StorageOrder == ColMajor) ptranspose(block); @@ -724,7 +727,7 @@ struct rhs_cpack if(PanelMode) ri += offset*vectorSize; - for(; i + 
vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock cblock; if(StorageOrder == ColMajor) @@ -752,10 +755,10 @@ struct rhs_cpack } PacketBlock block; - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[4].v, p16uc_GETIMAG32); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[5].v, p16uc_GETIMAG32); - block.packet[2] = vec_perm(cblock.packet[2].v , cblock.packet[6].v, p16uc_GETIMAG32); - block.packet[3] = vec_perm(cblock.packet[3].v , cblock.packet[7].v, p16uc_GETIMAG32); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32); + block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32); + block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32); if(Conjugate) { @@ -832,13 +835,14 @@ struct rhs_pack { { const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(; j + vectorSize < cols; j+=vectorSize) + + for(; j + vectorSize <= cols; j+=vectorSize) { Index i = 0; if(PanelMode) ri += offset*vectorSize; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; if(StorageOrder == ColMajor) @@ -883,13 +887,16 @@ struct rhs_pack { if(PanelMode) ri += offset*(cols - j); - for(Index i = 0; i < depth; i++) + if (j < cols) { - Index k = j; - for(; k < cols; k++) + for(Index i = 0; i < depth; i++) { - blockB[ri] = rhs(i, k); - ri += 1; + Index k = j; + for(; k < cols; k++) + { + blockB[ri] = rhs(i, k); + ri += 1; + } } } if(PanelMode) ri += (cols - j)*(stride - offset - depth); @@ -905,13 +912,13 @@ struct lhs_pack const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; if(StorageOrder == RowMajor) @@ -970,12 +977,12 @@ struct rhs_pack { const int vectorSize = quad_traits::vectorsize; Index ri = 0, j = 0; - for(; j + 2*vectorSize < cols; j+=2*vectorSize) + for(; j + 2*vectorSize <= cols; j+=2*vectorSize) { Index i = 0; if(PanelMode) ri += offset*(2*vectorSize); - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; if(StorageOrder == ColMajor) @@ -1059,13 +1066,13 @@ struct lhs_cpack(blockA); Packet conj = pset1(-1.0); - for(j = 0; j + vectorSize < rows; j+=vectorSize) + for(j = 0; j + vectorSize <= rows; j+=vectorSize) { Index i = 0; if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize < depth; i+=vectorSize) + for(; i + vectorSize <= depth; i+=vectorSize) { PacketBlock block; @@ -1078,8 +1085,8 @@ struct lhs_cpack(j + 1, i + 0); //[a2 a2i] cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2] - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2] + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] } else { cblock.packet[0] = lhs.template loadPacket(j + 0, i); //[a1 a1i] cblock.packet[1] = lhs.template loadPacket(j + 1, i); //[a2 a2i] @@ -1087,8 
+1094,8 @@ struct lhs_cpack(j + 0, i + 1); //[b1 b1i] cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2] - block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2] + block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] } pstore(blockAt + ri , block.packet[0]); @@ -1108,7 +1115,7 @@ struct lhs_cpack block; @@ -1121,8 +1128,8 @@ struct lhs_cpack(j + 1, i + 0); cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[2].v, p16uc_GETIMAG64); - block.packet[1] = vec_perm(cblock.packet[1].v , cblock.packet[3].v, p16uc_GETIMAG64); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64); } else { cblock.packet[0] = lhs.template loadPacket(j + 0, i); cblock.packet[1] = lhs.template loadPacket(j + 1, i); @@ -1130,8 +1137,8 @@ struct lhs_cpack(j + 0, i + 1); cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64); - block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); } if(Conjugate) @@ -1205,7 +1212,7 @@ struct rhs_cpack(-1.0); Index ri = 0, j = 0; - for(; j + 2*vectorSize < cols; j+=2*vectorSize) + for(; j + 2*vectorSize <= cols; j+=2*vectorSize) { Index i = 0; @@ -1221,8 +1228,8 @@ struct rhs_cpack(i, j + 2); cblock.packet[3] = rhs.template loadPacket(i, j + 3); - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETREAL64); - block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETREAL64); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); + block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); pstore(blockBt + ri , block.packet[0]); pstore(blockBt + ri + 2, block.packet[1]); @@ -1246,8 +1253,8 @@ struct rhs_cpack(i, j + 2); //[c1 c1i] cblock.packet[3] = rhs.template loadPacket(i, j + 3); //[d1 d1i] - block.packet[0] = vec_perm(cblock.packet[0].v , cblock.packet[1].v, p16uc_GETIMAG64); - block.packet[1] = vec_perm(cblock.packet[2].v , cblock.packet[3].v, p16uc_GETIMAG64); + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); if(Conjugate) { @@ -1300,25 +1307,84 @@ struct rhs_cpack -EIGEN_STRONG_INLINE void pger(PacketBlock *acc, const Scalar* lhs, const Scalar* rhs) +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) { - Packet lhsV = *((Packet *) lhs); - Packet rhsV1 = pset1(rhs[0]); - Packet rhsV2 = pset1(rhs[1]); - Packet rhsV3 = pset1(rhs[2]); - Packet rhsV4 = pset1(rhs[3]); + asm("#pger begin"); + Packet lhsV = pload(lhs); if(NegativeAccumulate) { - acc->packet[0] -= lhsV*rhsV1; - acc->packet[1] -= lhsV*rhsV2; - acc->packet[2] -= lhsV*rhsV3; - acc->packet[3] -= lhsV*rhsV4; + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_nmsub(lhsV, 
rhsV[1], acc->packet[1]); + acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); } else { - acc->packet[0] += lhsV*rhsV1; - acc->packet[1] += lhsV*rhsV2; - acc->packet[2] += lhsV*rhsV3; - acc->packet[3] += lhsV*rhsV4; + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); + acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); + } + asm("#pger end"); +} + +template +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) +{ + Packet lhsV = pload(lhs); + + if(NegativeAccumulate) + { + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + } else { + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); + } +} + +template +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows) +{ +#ifdef _ARCH_PWR9 + Packet lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); +#else + Packet lhsV; + Index i = 0; + do { + lhsV[i] = lhs[i]; + } while (++i < remaining_rows); +#endif + + if(NegativeAccumulate) + { + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]); + acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); + } else { + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); + acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); + } +} + +template +EIGEN_STRONG_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows) +{ +#ifdef _ARCH_PWR9 + Packet lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); +#else + Packet lhsV; + Index i = 0; + do { + lhsV[i] = lhs[i]; + } while (++i < remaining_rows); +#endif + + if(NegativeAccumulate) + { + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + } else { + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); } } @@ -1399,7 +1465,7 @@ EIGEN_STRONG_INLINE void pgerc(PacketBlock& accReal, PacketBlock EIGEN_STRONG_INLINE Packet ploadLhs(const Scalar *lhs) { - return *((Packet *)lhs); + return *((Packet *)lhs); } // Zero the accumulator on PacketBlock. @@ -1412,14 +1478,26 @@ EIGEN_STRONG_INLINE void bsetzero(PacketBlock& acc) acc.packet[3] = pset1((Scalar)0); } +template +EIGEN_STRONG_INLINE void bsetzero(PacketBlock& acc) +{ + acc.packet[0] = pset1((Scalar)0); +} + // Scale the PacketBlock vectors by alpha. template EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) { - acc.packet[0] = pmadd(pAlpha,accZ.packet[0], acc.packet[0]); - acc.packet[1] = pmadd(pAlpha,accZ.packet[1], acc.packet[1]); - acc.packet[2] = pmadd(pAlpha,accZ.packet[2], acc.packet[2]); - acc.packet[3] = pmadd(pAlpha,accZ.packet[3], acc.packet[3]); + acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); + acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]); + acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]); + acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]); +} + +template +EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +{ + acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); } // Complex version of PacketBlock scaling. 
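
For context, the `pger` overloads introduced above are the micro-kernel's rank-1 update: a packed column of LHS values is multiplied against broadcast RHS scalars and accumulated with fused multiply-add (`vec_madd`), or negative multiply-subtract (`vec_nmsub`) on the conjugate/negative-accumulate paths. A minimal standalone sketch of that pattern, using plain AltiVec/VSX types rather than Eigen's `Packet`/`PacketBlock` wrappers (the names `acc4` and `rank1_update` here are illustrative only, not Eigen API):

    // Illustrative sketch of the rank-1 update performed by pger() for float.
    // Requires a PowerPC compiler with VSX enabled (-mvsx).
    #include <altivec.h>

    struct acc4 { vector float p[4]; };     // stands in for PacketBlock<Packet4f,4>

    static inline void rank1_update(acc4* acc, const float* lhs, const float* rhs)
    {
      vector float lhsV = vec_xl(0, lhs);             // load 4 packed LHS rows
      for (int j = 0; j < 4; j++) {
        vector float rhsV = vec_splats(rhs[j]);       // broadcast one RHS value
        acc->p[j] = vec_madd(lhsV, rhsV, acc->p[j]);  // fused multiply-add into the tile
      }
    }

The remaining-rows overloads in the patch follow the same shape but load only `remaining_rows` LHS elements (via `vec_xl_len` on POWER9, or an element-wise loop otherwise) before the same `vec_madd`/`vec_nmsub` accumulation.
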
@@ -1471,16 +1549,453 @@ EIGEN_STRONG_INLINE void bload(PacketBlock& acc, const DataMapper& res acc.packet[7] = res.template loadPacket(row + (N+1)*accCols, col + 3); } +const static Packet4i mask41 = { -1, 0, 0, 0 }; +const static Packet4i mask42 = { -1, -1, 0, 0 }; +const static Packet4i mask43 = { -1, -1, -1, 0 }; + +const static Packet2l mask21 = { -1, 0 }; + +template +EIGEN_STRONG_INLINE Packet bmask(const int remaining_rows) +{ + if (remaining_rows == 0) { + return pset1(float(0.0)); + } else { + switch (remaining_rows) { + case 1: return Packet(mask41); + case 2: return Packet(mask42); + default: return Packet(mask43); + } + } +} + +template<> +EIGEN_STRONG_INLINE Packet2d bmask(const int remaining_rows) +{ + if (remaining_rows == 0) { + return pset1(double(0.0)); + } else { + return Packet2d(mask21); + } +} + +template +EIGEN_STRONG_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha, const Packet& pMask) +{ + acc.packet[0] = pmadd(pAlpha, pand(accZ.packet[0], pMask), acc.packet[0]); + acc.packet[1] = pmadd(pAlpha, pand(accZ.packet[1], pMask), acc.packet[1]); + acc.packet[2] = pmadd(pAlpha, pand(accZ.packet[2], pMask), acc.packet[2]); + acc.packet[3] = pmadd(pAlpha, pand(accZ.packet[3], pMask), acc.packet[3]); +} // PEEL loop factor. #define PEEL 10 +template +EIGEN_STRONG_INLINE void MICRO_EXTRA_COL( + const Scalar* &lhs_ptr, + const Scalar* &rhs_ptr, + PacketBlock &accZero, + Index remaining_rows, + Index remaining_cols) +{ + Packet rhsV[1]; + rhsV[0] = pset1(rhs_ptr[0]); + pger(&accZero, lhs_ptr, rhsV); + lhs_ptr += remaining_rows; + rhs_ptr += remaining_cols; +} + +template +EIGEN_STRONG_INLINE void gemm_extra_col( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlpha) +{ + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; + PacketBlock accZero, acc; + + bsetzero(accZero); + + Index remaining_depth = (depth & -accRows); + Index k = 0; + for(; k + PEEL <= remaining_depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr); + for (int l = 0; l < PEEL; l++) { + MICRO_EXTRA_COL(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols); + } + } + for(; k < remaining_depth; k++) + { + MICRO_EXTRA_COL(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols); + } + for(; k < depth; k++) + { + Packet rhsV[1]; + rhsV[0] = pset1(rhs_ptr[0]); + pger(&accZero, lhs_ptr, rhsV, remaining_rows); + lhs_ptr += remaining_rows; + rhs_ptr += remaining_cols; + } + + acc.packet[0] = vec_mul(pAlpha, accZero.packet[0]); + for(Index i = 0; i < remaining_rows; i++){ + res(row + i, col) += acc.packet[0][i]; + } +} + +template +EIGEN_STRONG_INLINE void MICRO_EXTRA_ROW( + const Scalar* &lhs_ptr, + const Scalar* &rhs_ptr, + PacketBlock &accZero, + Index remaining_rows) +{ + Packet rhsV[4]; + pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + pger(&accZero, lhs_ptr, rhsV); + lhs_ptr += remaining_rows; + rhs_ptr += accRows; +} + +template +EIGEN_STRONG_INLINE void gemm_extra_row( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask) +{ + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; + PacketBlock accZero, 
acc; + + bsetzero(accZero); + + Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows); + Index k = 0; + for(; k + PEEL <= remaining_depth; k+= PEEL) + { + prefetch(rhs_ptr); + prefetch(lhs_ptr); + for (int l = 0; l < PEEL; l++) { + MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); + } + } + for(; k < remaining_depth; k++) + { + MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); + } + + if (remaining_depth == depth) + { + for(Index j = 0; j < 4; j++){ + acc.packet[j] = res.template loadPacket(row, col + j); + } + bscale(acc, accZero, pAlpha, pMask); + res.template storePacketBlock(row, col, acc); + } else { + for(; k < depth; k++) + { + Packet rhsV[4]; + pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + pger(&accZero, lhs_ptr, rhsV, remaining_rows); + lhs_ptr += remaining_rows; + rhs_ptr += accRows; + } + + for(Index j = 0; j < 4; j++){ + acc.packet[j] = vec_mul(pAlpha, accZero.packet[j]); + } + for(Index j = 0; j < 4; j++){ + for(Index i = 0; i < remaining_rows; i++){ + res(row + i, col + j) += acc.packet[j][i]; + } + } + } +} + +#define MICRO_DST \ + PacketBlock *accZero0, PacketBlock *accZero1, PacketBlock *accZero2, \ + PacketBlock *accZero3, PacketBlock *accZero4, PacketBlock *accZero5, \ + PacketBlock *accZero6, PacketBlock *accZero7 + +#define MICRO_COL_DST \ + PacketBlock *accZero0, PacketBlock *accZero1, PacketBlock *accZero2, \ + PacketBlock *accZero3, PacketBlock *accZero4, PacketBlock *accZero5, \ + PacketBlock *accZero6, PacketBlock *accZero7 + +#define MICRO_SRC \ + const Scalar **lhs_ptr0, const Scalar **lhs_ptr1, const Scalar **lhs_ptr2, \ + const Scalar **lhs_ptr3, const Scalar **lhs_ptr4, const Scalar **lhs_ptr5, \ + const Scalar **lhs_ptr6, const Scalar **lhs_ptr7 + +#define MICRO_ONE \ + MICRO(\ + &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ + rhs_ptr, \ + &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7); + +#define MICRO_COL_ONE \ + MICRO_COL(\ + &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ + rhs_ptr, \ + &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7, \ + remaining_cols); + +#define MICRO_WORK_ONE(iter) \ + if (N > iter) { \ + pger(accZero##iter, *lhs_ptr##iter, rhsV); \ + *lhs_ptr##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(accZero##iter); \ + EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + } + +#define MICRO_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) + +#define MICRO_WORK MICRO_UNROLL(MICRO_WORK_ONE) + +#define MICRO_DST_PTR_ONE(iter) \ + if (unroll_factor > iter){ \ + bsetzero(accZero##iter); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accZero##iter); \ + } + +#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE) + +#define MICRO_SRC_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + } + +#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE) + +#define MICRO_PREFETCH_ONE(iter) \ + if (unroll_factor > iter){ \ + prefetch(lhs_ptr##iter); \ + } + +#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE) + +#define MICRO_STORE_ONE(iter) \ + if (unroll_factor > iter){ \ + acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ + acc.packet[1] = res.template loadPacket(row + iter*accCols, col + 1); \ + acc.packet[2] = res.template loadPacket(row + 
iter*accCols, col + 2); \ + acc.packet[3] = res.template loadPacket(row + iter*accCols, col + 3); \ + bscale(acc, accZero##iter, pAlpha); \ + res.template storePacketBlock(row + iter*accCols, col, acc); \ + } + +#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE) + +#define MICRO_COL_STORE_ONE(iter) \ + if (unroll_factor > iter){ \ + acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ + bscale(acc, accZero##iter, pAlpha); \ + res.template storePacketBlock(row + iter*accCols, col, acc); \ + } + +#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE) + +template +EIGEN_STRONG_INLINE void MICRO( + MICRO_SRC, + const Scalar* &rhs_ptr, + MICRO_DST) + { + Packet rhsV[4]; + pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + asm("#unrolled pger? begin"); + MICRO_WORK + asm("#unrolled pger? end"); + rhs_ptr += accRows; + } + +template +EIGEN_STRONG_INLINE void gemm_unrolled_iteration( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index col, + const Packet& pAlpha) +{ + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7; + PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + PacketBlock acc; + + asm("#unrolled start"); + MICRO_SRC_PTR + asm("#unrolled zero?"); + MICRO_DST_PTR + + Index k = 0; + for(; k + PEEL <= depth; k+= PEEL) + { + prefetch(rhs_ptr); + MICRO_PREFETCH + asm("#unrolled inner loop?"); + for (int l = 0; l < PEEL; l++) { + MICRO_ONE + } + asm("#unrolled inner loop end?"); + } + for(; k < depth; k++) + { + MICRO_ONE + } + MICRO_STORE + + row += unroll_factor*accCols; +} + +template +EIGEN_STRONG_INLINE void MICRO_COL( + MICRO_SRC, + const Scalar* &rhs_ptr, + MICRO_COL_DST, + Index remaining_rows) + { + Packet rhsV[1]; + rhsV[0] = pset1(rhs_ptr[0]); + asm("#unrolled pger? begin"); + MICRO_WORK + asm("#unrolled pger? 
end"); + rhs_ptr += remaining_rows; + } + +template +EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index col, + Index remaining_cols, + const Packet& pAlpha) +{ + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7; + PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + PacketBlock acc; + + MICRO_SRC_PTR + MICRO_DST_PTR + + Index k = 0; + for(; k + PEEL <= depth; k+= PEEL) + { + prefetch(rhs_ptr); + MICRO_PREFETCH + for (int l = 0; l < PEEL; l++) { + MICRO_COL_ONE + } + } + for(; k < depth; k++) + { + MICRO_COL_ONE + } + MICRO_COL_STORE + + row += unroll_factor*accCols; +} + +template +EIGEN_STRONG_INLINE void gemm_unrolled_col( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlpha) +{ +#define MAX_UNROLL 6 + while(row + MAX_UNROLL*accCols <= rows){ + gemm_unrolled_col_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + } + switch( (rows-row)/accCols ){ +#if MAX_UNROLL > 7 + case 7: + gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 6 + case 6: + gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 5 + case 5: + gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 4 + case 4: + gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 3 + case 3: + gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 2 + case 2: + gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 1 + case 1: + gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif + default: + break; + } +#undef MAX_UNROLL +} + /**************** * GEMM kernels * * **************/ -template -EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, - Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) +template +EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { const Index remaining_rows = rows % accCols; const Index remaining_cols = cols % accRows; @@ -1489,516 +2004,91 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& 
res, const Scalar* blockA, const if( strideB == -1 ) strideB = depth; const Packet pAlpha = pset1(alpha); + const Packet pMask = bmask((const int)(remaining_rows)); + Index col = 0; for(; col + accRows <= cols; col += accRows) { - const Scalar *rhs_base = blockB + ( col/accRows )*strideB*accRows; + const Scalar *rhs_base = blockB + col*strideB + accRows*offsetB; const Scalar *lhs_base = blockA; - Index row = 0; - for(; row + 6*accCols <= rows; row += 6*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - pger(&accZero3, lhs_ptr3, rhs_ptr); \ - lhs_ptr3 += accCols; \ - pger(&accZero4, lhs_ptr4, rhs_ptr); \ - lhs_ptr4 += accCols; \ - pger(&accZero5, lhs_ptr5, rhs_ptr); \ - lhs_ptr5 += accCols; \ - pger(&accZero6, lhs_ptr6, rhs_ptr); \ - lhs_ptr6 += accCols; \ - rhs_ptr += accRows; - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; - const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols; - const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols; - const Scalar *lhs_ptr6 = lhs_base + ((row/accCols) + 5)*strideA*accCols; - - PacketBlock acc1, accZero1; - PacketBlock acc2, accZero2; - PacketBlock acc3, accZero3; - PacketBlock acc4, accZero4; - PacketBlock acc5, accZero5; - PacketBlock acc6, accZero6; - - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); - bload(acc3, res, row, col, accCols); - bsetzero(accZero3); - bload(acc4, res, row, col, accCols); - bsetzero(accZero4); - bload(acc5, res, row, col, accCols); - bsetzero(accZero5); - bload(acc6, res, row, col, accCols); - bsetzero(accZero6); - - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - lhs_ptr3 += accCols*offsetA; - lhs_ptr4 += accCols*offsetA; - lhs_ptr5 += accCols*offsetA; - lhs_ptr6 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - - Index k = 0; - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); - prefetch(lhs_ptr3); - prefetch(lhs_ptr4); - prefetch(lhs_ptr5); - prefetch(lhs_ptr6); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); -#endif - } - for(; k < depth; k++) - { - MICRO(); - } - - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); - bscale(acc3,accZero3, pAlpha); - bscale(acc4,accZero4, pAlpha); - bscale(acc5,accZero5, pAlpha); - bscale(acc6,accZero6, pAlpha); - - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); - res.template storePacketBlock(row + 2*accCols, col, acc3); - res.template storePacketBlock(row + 3*accCols, col, acc4); - res.template storePacketBlock(row + 4*accCols, col, acc5); - res.template storePacketBlock(row + 5*accCols, col, acc6); -#undef MICRO + asm("#jump table"); +#define MAX_UNROLL 6 + while(row + MAX_UNROLL*accCols <= rows){ + gemm_unrolled_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); } - for(; row + 5*accCols <= rows; row += 5*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - pger(&accZero3, lhs_ptr3, rhs_ptr); \ - 
lhs_ptr3 += accCols; \ - pger(&accZero4, lhs_ptr4, rhs_ptr); \ - lhs_ptr4 += accCols; \ - pger(&accZero5, lhs_ptr5, rhs_ptr); \ - lhs_ptr5 += accCols; \ - rhs_ptr += accRows; - - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; - const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols; - const Scalar *lhs_ptr5 = lhs_base + ((row/accCols) + 4)*strideA*accCols; - - PacketBlock acc1, accZero1; - PacketBlock acc2, accZero2; - PacketBlock acc3, accZero3; - PacketBlock acc4, accZero4; - PacketBlock acc5, accZero5; - - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); - bload(acc3, res, row, col, accCols); - bsetzero(accZero3); - bload(acc4, res, row, col, accCols); - bsetzero(accZero4); - bload(acc5, res, row, col, accCols); - bsetzero(accZero5); - - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - lhs_ptr3 += accCols*offsetA; - lhs_ptr4 += accCols*offsetA; - lhs_ptr5 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; - - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); - prefetch(lhs_ptr3); - prefetch(lhs_ptr4); - prefetch(lhs_ptr5); - - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); + switch( (rows-row)/accCols ){ +#if MAX_UNROLL > 7 + case 7: + gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; #endif - } - for(; k < depth; k++) - { - MICRO(); - } - - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); - bscale(acc3,accZero3, pAlpha); - bscale(acc4,accZero4, pAlpha); - bscale(acc5,accZero5, pAlpha); - - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); - res.template storePacketBlock(row + 2*accCols, col, acc3); - res.template storePacketBlock(row + 3*accCols, col, acc4); - res.template storePacketBlock(row + 4*accCols, col, acc5); -#undef MICRO - } - for(; row + 4*accCols <= rows; row += 4*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - pger(&accZero3, lhs_ptr3, rhs_ptr); \ - lhs_ptr3 += accCols; \ - pger(&accZero4, lhs_ptr4, rhs_ptr); \ - lhs_ptr4 += accCols; \ - rhs_ptr += accRows; - - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; - const Scalar *lhs_ptr4 = lhs_base + ((row/accCols) + 3)*strideA*accCols; - - PacketBlock acc1, accZero1; - PacketBlock acc2, accZero2; - PacketBlock acc3, accZero3; - PacketBlock acc4, accZero4; - - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); - bload(acc3, res, row, col, accCols); - bsetzero(accZero3); - bload(acc4, res, row, col, accCols); - bsetzero(accZero4); - - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - lhs_ptr3 += accCols*offsetA; - lhs_ptr4 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; - - for(; k + 
PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); - prefetch(lhs_ptr3); - prefetch(lhs_ptr4); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); +#if MAX_UNROLL > 6 + case 6: + gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; #endif - } - for(; k < depth; k++) - { - MICRO(); - } - - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); - bscale(acc3,accZero3, pAlpha); - bscale(acc4,accZero4, pAlpha); - - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); - res.template storePacketBlock(row + 2*accCols, col, acc3); - res.template storePacketBlock(row + 3*accCols, col, acc4); -#undef MICRO - } - for(; row + 3*accCols <= rows; row += 3*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - pger(&accZero3, lhs_ptr3, rhs_ptr); \ - lhs_ptr3 += accCols; \ - rhs_ptr += accRows; - - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - const Scalar *lhs_ptr3 = lhs_base + ((row/accCols) + 2)*strideA*accCols; - - PacketBlock acc1, accZero1; - PacketBlock acc2, accZero2; - PacketBlock acc3, accZero3; - - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); - bload(acc3, res, row, col, accCols); - bsetzero(accZero3); - - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - lhs_ptr3 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); - prefetch(lhs_ptr3); - - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); +#if MAX_UNROLL > 5 + case 5: + gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; #endif - } - for(; k < depth; k++) - { - MICRO(); - } - - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); - bscale(acc3,accZero3, pAlpha); - - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); - res.template storePacketBlock(row + 2*accCols, col, acc3); -#undef MICRO - } - for(; row + 2*accCols <= rows; row += 2*accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - pger(&accZero2, lhs_ptr2, rhs_ptr); \ - lhs_ptr2 += accCols; \ - rhs_ptr += accRows; - - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols )*strideA*accCols; - const Scalar *lhs_ptr2 = lhs_base + ((row/accCols) + 1)*strideA*accCols; - - PacketBlock acc1, accZero1; - PacketBlock acc2, accZero2; - - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - bload(acc2, res, row, col, accCols); - bsetzero(accZero2); - - lhs_ptr1 += accCols*offsetA; - lhs_ptr2 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - prefetch(lhs_ptr2); - - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - 
MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); +#if MAX_UNROLL > 4 + case 4: + gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; #endif - } - for(; k < depth; k++) - { - MICRO(); - } - - bscale(acc1,accZero1, pAlpha); - bscale(acc2,accZero2, pAlpha); - - res.template storePacketBlock(row + 0*accCols, col, acc1); - res.template storePacketBlock(row + 1*accCols, col, acc2); -#undef MICRO - } - - for(; row + accCols <= rows; row += accCols) - { -#define MICRO() \ - pger(&accZero1, lhs_ptr1, rhs_ptr); \ - lhs_ptr1 += accCols; \ - rhs_ptr += accRows; - - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; - - PacketBlock acc1, accZero1; - - bload(acc1, res, row, col, accCols); - bsetzero(accZero1); - - lhs_ptr1 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - Index k = 0; - for(; k + PEEL < depth; k+= PEEL) - { - prefetch(rhs_ptr); - prefetch(lhs_ptr1); - - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); - MICRO(); -#if PEEL > 8 - MICRO(); - MICRO(); +#if MAX_UNROLL > 3 + case 3: + gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; #endif - } - for(; k < depth; k++) - { - MICRO(); - } - - bscale(acc1,accZero1, pAlpha); - - res.template storePacketBlock(row, col, acc1); -#undef MICRO +#if MAX_UNROLL > 2 + case 2: + gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 1 + case 1: + gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif + default: + break; } +#undef MAX_UNROLL + asm("#jump table end"); + if(remaining_rows > 0) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; - - lhs_ptr += remaining_rows*offsetA; - rhs_ptr += accRows*offsetB; - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < remaining_rows; arow++) - { - for(Index acol = 0; acol < accRows; acol++ ) - { - res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += accRows; - lhs_ptr += remaining_rows; - } + gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, cols, remaining_rows, pAlpha, pMask); } } if(remaining_cols > 0) { - const Scalar *rhs_base = blockB + (col/accRows)*strideB*accRows; + const Scalar *rhs_base = blockB + col*strideB + remaining_cols*offsetB; const Scalar *lhs_base = blockA; - Index row = 0; - for(; row + accCols <= rows; row += accCols) + for(; col < cols; col++) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + Index row = 0; - lhs_ptr += accCols*offsetA; - rhs_ptr += remaining_cols*offsetB; - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < accCols; arow++) - { - for(Index acol = 0; acol < remaining_cols; acol++ ) - { - res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += remaining_cols; - lhs_ptr += accCols; - } - } - - if(remaining_rows > 0 ) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + gemm_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, 
remaining_cols, pAlpha); - lhs_ptr += remaining_rows*offsetA; - rhs_ptr += remaining_cols*offsetB; - for(Index k = 0; k < depth; k++) + if (remaining_rows > 0) { - for(Index arow = 0; arow < remaining_rows; arow++) - { - for(Index acol = 0; acol < remaining_cols; acol++ ) - { - res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += remaining_cols; - lhs_ptr += remaining_rows; + gemm_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha); } + rhs_base++; } } } -template +template EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, - Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) + Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { const int remaining_rows = rows % accCols; const int remaining_cols = cols % accRows; @@ -2018,7 +2108,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl const Scalar *blockA = (Scalar *) blockAc; const Scalar *blockB = (Scalar *) blockBc; - Packet conj = pset1((Scalar)-1.0f); + Packet conj = pset1((Scalar)-1.0); Index col = 0; for(; col + accRows <= cols; col += accRows) @@ -2054,7 +2144,7 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl if(!RhsIsReal) rhs_ptr_imag += accRows*offsetB; Index k = 0; - for(; k + PEEL < depth; k+=PEEL) + for(; k + PEEL <= depth; k+=PEEL) { prefetch(rhs_ptr); prefetch(rhs_ptr_imag); @@ -2180,8 +2270,8 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl { for(Index acol = 0; acol < 4; acol++ ) { - scalarAcc[arow][acol].real((Scalar)0.0f); - scalarAcc[arow][acol].imag((Scalar)0.0f); + scalarAcc[arow][acol].real((Scalar)0.0); + scalarAcc[arow][acol].imag((Scalar)0.0); } } for(Index k = 0; k < depth; k++) @@ -2550,24 +2640,24 @@ void gebp_kernel::rows; - const int accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index, const int, const int); + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemmMMA; + gemm_function = &Eigen::internal::gemmMMA; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemmMMA; + gemm_function = &Eigen::internal::gemmMMA; } else{ - gemm_function = &Eigen::internal::gemm; + gemm_function = &Eigen::internal::gemm; } #else - gemm_function = &Eigen::internal::gemm; + gemm_function = &Eigen::internal::gemm; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2591,22 +2681,22 @@ void gebp_kernel, std::complex, Index, DataMapper, mr const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, - Index, Index, Index, std::complex, Index, Index , Index, 
Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; } else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; } #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2630,21 +2720,21 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjugat const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const float*, const std::complex*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; } else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, 
float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; } #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2668,21 +2758,21 @@ void gebp_kernel, float, Index, DataMapper, mr, nr, Conjugat const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const float*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; } else{ - gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; } #else - gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2702,24 +2792,24 @@ void gebp_kernel::rows; - const int accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index, const int, const int); + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemmMMA; + gemm_function = 
&Eigen::internal::gemmMMA; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemmMMA; + gemm_function = &Eigen::internal::gemmMMA; } else{ - gemm_function = &Eigen::internal::gemm; + gemm_function = &Eigen::internal::gemm; } #else - gemm_function = &Eigen::internal::gemm; + gemm_function = &Eigen::internal::gemm; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2743,21 +2833,21 @@ void gebp_kernel, std::complex, Index, DataMapper, const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; } else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; } #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2781,21 +2871,21 @@ void gebp_kernel, double, Index, DataMapper, mr, nr, Conjug const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const std::complex*, const double*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, 
double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; } else{ - gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; } #else - gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, false, true>; + gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } template @@ -2819,21 +2909,21 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjug const int accRows = quad_traits::rows; const int accCols = quad_traits::size; void (*gemm_function)(const DataMapper&, const double*, const std::complex*, - Index, Index, Index, std::complex, Index, Index , Index, Index, const int, const int); + Index, Index, Index, std::complex, Index, Index, Index, Index); #ifdef EIGEN_ALTIVEC_MMA_ONLY //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; } else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; } #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, ConjugateLhs, ConjugateRhs, true, 
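The complex specializations also end with two extra boolean template arguments (the false,true / true,false / false,false pairs), which, judging by the kernel scalar types, select whether the LHS or RHS packing is purely real. A minimal illustrative sketch of how such compile-time flags let the unneeded imaginary terms drop out; this is not Eigen's kernel, just the idea, and the names are hypothetical.

// One complex multiply-accumulate step with real/complex selection flags.
template<typename Scalar, bool LhsIsReal, bool RhsIsReal>
void cmadd(Scalar& acc_re, Scalar& acc_im,
           Scalar lre, Scalar lim, Scalar rre, Scalar rim)
{
  acc_re += lre * rre;
  if (!LhsIsReal && !RhsIsReal) acc_re -= lim * rim;  // only complex*complex needs this term
  if (!RhsIsReal)               acc_im += lre * rim;  // dead code when the RHS is real
  if (!LhsIsReal)               acc_im += lim * rre;  // dead code when the LHS is real
}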
false>; + gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; #endif - gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB, accRows, accCols); + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } } // end namespace internal diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h index 1866a71bf..a67dbccf3 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -16,29 +16,22 @@ namespace Eigen { namespace internal { -template -union Packetx2u -{ - __vector_pair vectorpair; - PacketBlock pair; -}; const static Packet16uc MMA_p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, - 16, 17, 18, 19, - 4, 5, 6, 7, - 20, 21, 22, 23}; + 16, 17, 18, 19, + 4, 5, 6, 7, + 20, 21, 22, 23}; const static Packet16uc MMA_p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11, - 24, 25, 26, 27, - 12, 13, 14, 15, - 28, 29, 30, 31}; + 24, 25, 26, 27, + 12, 13, 14, 15, + 28, 29, 30, 31}; //[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64 const static Packet16uc MMA_p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7, - 16, 17, 18, 19, 20, 21, 22, 23}; + 16, 17, 18, 19, 20, 21, 22, 23}; //[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64 const static Packet16uc MMA_p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15, - 24, 25, 26, 27, 28, 29, 30, 31}; - + 24, 25, 26, 27, 28, 29, 30, 31}; // Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. @@ -93,11 +86,11 @@ EIGEN_STRONG_INLINE void bcoupleMMA(PacketBlock template EIGEN_STRONG_INLINE Packet ploadLhsMMA(const Scalar *lhs) { - return *((Packet *)lhs); + return *((Packet *)lhs); } template -EIGEN_STRONG_INLINE PacketBlock pmul (const PacketBlock& a, const Packet& b) +EIGEN_STRONG_INLINE PacketBlock pmul(const PacketBlock& a, const Packet& b) { PacketBlock pb; pb.packet[0] = a.packet[0]*b; @@ -117,13 +110,12 @@ EIGEN_STRONG_INLINE void storeAccumulator(Index i, Index j, const DataMapper& da PacketBlock result; __builtin_mma_disassemble_acc(&result.packet, acc); - PacketBlock block; - block.packet[0] = data.template loadPacket(i, j + 0) + pmul(alpha, result.packet[0]); - block.packet[1] = data.template loadPacket(i, j + 1) + pmul(alpha, result.packet[1]); - block.packet[2] = data.template loadPacket(i, j + 2) + pmul(alpha, result.packet[2]); - block.packet[3] = data.template loadPacket(i, j + 3) + pmul(alpha, result.packet[3]); + result.packet[0] = pmadd(alpha, result.packet[0], data.template loadPacket(i, j + 0)); + result.packet[1] = pmadd(alpha, result.packet[1], data.template loadPacket(i, j + 1)); + result.packet[2] = pmadd(alpha, result.packet[2], data.template loadPacket(i, j + 2)); + result.packet[3] = pmadd(alpha, result.packet[3], data.template loadPacket(i, j + 3)); - data.template storePacketBlock(i, j, block); + data.template storePacketBlock(i, j, result); } template @@ -187,38 +179,221 @@ EIGEN_STRONG_INLINE void pgerMMA(__vector_quad *acc, const RhsPacket& a, const L template<> EIGEN_STRONG_INLINE void pgerMMA, false>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) { - Packetx2u p; - p.pair = a; - __builtin_mma_xvf64gerpp(acc, p.vectorpair, (__vector unsigned char)b); + __vector_pair *a0 = (__vector_pair *)(&a.packet[0]); + 
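On the MMA side, storeAccumulator now folds the alpha scaling into a fused multiply-add against the freshly loaded result block instead of building a separate PacketBlock with a multiply and an add. A sketch of that step for the float path, assuming a GCC/Clang toolchain targeting Power10 (-mcpu=power10) and a plain pointer plus leading dimension in place of Eigen's DataMapper.

#include <altivec.h>

void store_acc_f32(float* C, long ldc, __vector float alpha, __vector_quad* acc)
{
  __vector float result[4];
  __builtin_mma_disassemble_acc(result, acc);     // 4x4 tile as four 4-float vectors
  for (int j = 0; j < 4; ++j)
  {
    __vector float c = vec_xl(0, C + j * ldc);    // load the existing slice of C
    c = vec_madd(alpha, result[j], c);            // alpha*acc + C in a single fma
    vec_xst(c, 0, C + j * ldc);
  }
}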
__builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b); } template<> EIGEN_STRONG_INLINE void pgerMMA, true>(__vector_quad *acc, const PacketBlock& a, const Packet2d& b) { - Packetx2u p; - p.pair = a; - __builtin_mma_xvf64gernp(acc, p.vectorpair, (__vector unsigned char)b); + __vector_pair *a0 = (__vector_pair *)(&a.packet[0]); + __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b); +} + +template<> +EIGEN_STRONG_INLINE void pgerMMA(__vector_quad *acc, const __vector_pair& a, const Packet2d& b) +{ + __builtin_mma_xvf64gerpp(acc, a, (__vector unsigned char)b); +} + +template<> +EIGEN_STRONG_INLINE void pgerMMA(__vector_quad *acc, const __vector_pair& a, const Packet2d& b) +{ + __builtin_mma_xvf64gernp(acc, a, (__vector unsigned char)b); } // This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. template -EIGEN_STRONG_INLINE Packet ploadRhsMMA(const Scalar *rhs) +EIGEN_STRONG_INLINE void ploadRhsMMA(const Scalar *rhs, Packet &rhsV) { - return *((Packet *)rhs); + rhsV = *((Packet *)rhs); } template<> -EIGEN_STRONG_INLINE PacketBlock ploadRhsMMA >(const double *rhs) +EIGEN_STRONG_INLINE void ploadRhsMMA >(const double *rhs, PacketBlock &rhsV) { - PacketBlock pair; - pair.packet[0] = *((Packet2d *)rhs ); - pair.packet[1] = *(((Packet2d *)rhs) + 1); - return pair; + rhsV.packet[0] = *((Packet2d *)rhs ); + rhsV.packet[1] = *(((Packet2d *)rhs) + 1); } -template -void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, - Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) +template<> +EIGEN_STRONG_INLINE void ploadRhsMMA(const double *rhs, __vector_pair &rhsV) +{ + __builtin_mma_assemble_pair(&rhsV, (__vector unsigned char)(*(((Packet2d *)rhs) + 1)), (__vector unsigned char)(*((Packet2d *)rhs))); +} + +template +EIGEN_STRONG_INLINE void gemm_extra_col( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlpha); + +template +EIGEN_STRONG_INLINE void gemm_extra_row( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask); + +template +EIGEN_STRONG_INLINE void gemm_unrolled_col( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlpha); + +template +EIGEN_STRONG_INLINE Packet bmask(const int remaining_rows); + +#define MICRO_MMA_DST \ + __vector_quad *accZero0, __vector_quad *accZero1, __vector_quad *accZero2, \ + __vector_quad *accZero3, __vector_quad *accZero4, __vector_quad *accZero5, \ + __vector_quad *accZero6, __vector_quad *accZero7 + +#define MICRO_MMA_SRC \ + const Scalar **lhs_ptr0, const Scalar **lhs_ptr1, const Scalar **lhs_ptr2, \ + const Scalar **lhs_ptr3, const Scalar **lhs_ptr4, const Scalar **lhs_ptr5, \ + const Scalar **lhs_ptr6, const Scalar **lhs_ptr7 + +#define MICRO_MMA_ONE \ + if (sizeof(Scalar) == sizeof(float)) { \ + MICRO_MMA(\ + &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ + rhs_ptr, \ + &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, 
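The double-precision rank-1 updates no longer go through the removed Packetx2u union: the two RHS vectors are assembled directly into a __vector_pair (note that the patch passes the second 16-byte chunk first) and handed to the f64 MMA ger builtin. A sketch of one such update, again assuming a Power10-enabled GCC; the pointer layout is illustrative only.

#include <altivec.h>

void dger_step(__vector_quad* acc, const double* rhs, __vector double lhs)
{
  __vector_pair rhsV;
  __vector double r0 = vec_xl(0,  rhs);   // rhs[0], rhs[1]
  __vector double r1 = vec_xl(16, rhs);   // rhs[2], rhs[3]
  // Same operand order as the patch: the second vector goes in first.
  __builtin_mma_assemble_pair(&rhsV, (__vector unsigned char)r1, (__vector unsigned char)r0);
  // Accumulate the f64 outer product of the pair and lhs into the accumulator.
  __builtin_mma_xvf64gerpp(acc, rhsV, (__vector unsigned char)lhs);
}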
&accZero7); \ + } else { \ + MICRO_MMA(\ + &lhs_ptr0, &lhs_ptr1, &lhs_ptr2, &lhs_ptr3, &lhs_ptr4, &lhs_ptr5, &lhs_ptr6, &lhs_ptr7, \ + rhs_ptr, \ + &accZero0, &accZero1, &accZero2, &accZero3, &accZero4, &accZero5, &accZero6, &accZero7); \ + } + +#define MICRO_MMA_WORK_ONE(iter) \ + if (N > iter) { \ + Packet lhsV = ploadLhsMMA(*lhs_ptr##iter); \ + pgerMMA(accZero##iter, rhsV, lhsV); \ + *lhs_ptr##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(accZero##iter); \ + EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + } + +#define MICRO_MMA_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) + +#define MICRO_MMA_WORK MICRO_MMA_UNROLL(MICRO_MMA_WORK_ONE) + +#define MICRO_MMA_DST_PTR_ONE(iter) \ + if (unroll_factor > iter){ \ + bsetzeroMMA(&accZero##iter); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accZero##iter); \ + } + +#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE) + +#define MICRO_MMA_SRC_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + } + +#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE) + +#define MICRO_MMA_PREFETCH_ONE(iter) \ + if (unroll_factor > iter){ \ + prefetch(lhs_ptr##iter); \ + } + +#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE) + +#define MICRO_MMA_STORE_ONE(iter) \ + if (unroll_factor > iter){ \ + storeAccumulator(row + iter*accCols, col, res, pAlpha, &accZero##iter); \ + } + +#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE) + +// PEEL_MMA loop factor. +#define PEEL_MMA 10 + +template +EIGEN_STRONG_INLINE void MICRO_MMA( + MICRO_MMA_SRC, + const Scalar* &rhs_ptr, + MICRO_MMA_DST) + { + RhsPacket rhsV; + ploadRhsMMA(rhs_ptr, rhsV); + MICRO_MMA_WORK + rhs_ptr += accRows; + } + +template +EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration( + const DataMapper& res, + const Scalar *lhs_base, + const Scalar *rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index col, + const Packet& pAlpha) +{ + const Scalar *rhs_ptr = rhs_base; + const Scalar *lhs_ptr0, *lhs_ptr1, *lhs_ptr2, *lhs_ptr3, *lhs_ptr4, *lhs_ptr5, *lhs_ptr6, *lhs_ptr7; + __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + + asm("#unrolled MMA start"); + MICRO_MMA_SRC_PTR + MICRO_MMA_DST_PTR + + Index k = 0; + for(; k + PEEL_MMA <= depth; k+= PEEL_MMA) + { + prefetch(rhs_ptr); + MICRO_MMA_PREFETCH + for (int l = 0; l < PEEL_MMA; l++) { + MICRO_MMA_ONE + } + } + for(; k < depth; k++) + { + MICRO_MMA_ONE + } + MICRO_MMA_STORE + + row += unroll_factor*accCols; + asm("#unrolled MMA end"); +} + +template +void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { const Index remaining_rows = rows % accCols; const Index remaining_cols = cols % accRows; @@ -227,111 +402,89 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, if( strideB == -1 ) strideB = depth; const Packet pAlpha = pset1(alpha); + const Packet pMask = bmask((const int)(remaining_rows)); + Index col = 0; for(; col + accRows <= cols; col += accRows) { - const Scalar *rhs_base = blockB + ( col/accRows )*strideB*accRows; + const Scalar *rhs_base = blockB + col*strideB + accRows*offsetB; const Scalar *lhs_base = blockA; Index row = 0; - for(; row + accCols <= rows; row += accCols) - { - const 
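gemm_unrolled_MMA_iteration drives the depth loop through the MICRO_MMA_* macros with a peel factor of PEEL_MMA (10 in the patch), so the prefetches and macro expansions are issued in groups and only a short scalar tail remains. The loop structure, stripped of the intrinsics, with a lambda standing in for one MICRO_MMA_ONE expansion:

template<typename Step>
void peeled_depth_loop(long depth, Step step)
{
  const long PEEL = 10;                  // PEEL_MMA in the patch
  long k = 0;
  for (; k + PEEL <= depth; k += PEEL)
  {
    // prefetch(rhs_ptr) and MICRO_MMA_PREFETCH would go here in the real kernel
    for (int l = 0; l < PEEL; ++l)
      step();                            // one MICRO_MMA_ONE expansion
  }
  for (; k < depth; ++k)
    step();                              // leftover depth iterations
}

A call would look like peeled_depth_loop(depth, [&]{ /* one rank-1 update */ });, which mirrors how the kernel expands the same macro body in both the peeled and the tail loop.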
Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr1 = lhs_base + (row/accCols)*strideA*accCols; - - __vector_quad acc; - bsetzeroMMA(&acc); - - lhs_ptr1 += accCols*offsetA; - rhs_ptr += accRows*offsetB; - for(Index k = 0; k < depth; k++) - { - Packet lhsV = ploadLhsMMA(lhs_ptr1); - RhsPacket rhsV = ploadRhsMMA(rhs_ptr); - - pgerMMA(&acc, rhsV, lhsV); - - lhs_ptr1 += accCols; - rhs_ptr += accRows; - } - - storeAccumulator(row, col, res, pAlpha, &acc); +#define MAX_MMA_UNROLL 7 + while(row + MAX_MMA_UNROLL*accCols <= rows){ + gemm_unrolled_MMA_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); } + switch( (rows-row)/accCols ){ +#if MAX_MMA_UNROLL > 7 + case 7: + gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 6 + case 6: + gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 5 + case 5: + gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 4 + case 4: + gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 3 + case 3: + gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 2 + case 2: + gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 1 + case 1: + gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif + default: + break; + } +#undef MAX_MMA_UNROLL + if(remaining_rows > 0) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; - - lhs_ptr += remaining_rows*offsetA; - rhs_ptr += accRows*offsetB; - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < remaining_rows; arow++) - { - for(Index acol = 0; acol < accRows; acol++ ) - { - res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += accRows; - lhs_ptr += remaining_rows; - } + gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, cols, remaining_rows, pAlpha, pMask); } } if(remaining_cols > 0) { - const Scalar *rhs_base = blockB + (col/accRows)*strideB*accRows; + const Scalar *rhs_base = blockB + col*strideB + remaining_cols*offsetB; const Scalar *lhs_base = blockA; - Index row = 0; - for(; row + accCols <= rows; row += accCols) + for(; col < cols; col++) { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + Index row = 0; - lhs_ptr += accCols*offsetA; - rhs_ptr += remaining_cols*offsetB; - for(Index k = 0; k < depth; k++) - { - for(Index arow = 0; arow < accCols; arow++) - { - for(Index acol = 0; acol < remaining_cols; acol++ ) - { - res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += 
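The new gemmMMA main loop consumes full blocks of MAX_MMA_UNROLL*accCols rows with the widest unrolled iteration and then switches on the leftover block count to pick a smaller compile-time unroll factor, so every path still runs with a fixed unroll_factor. A reduced sketch of that dispatch, with do_block<N> standing in for gemm_unrolled_MMA_iteration<N, ...>:

template<int unroll_factor>
void do_block(long& row, long accCols) { row += unroll_factor * accCols; }

void row_dispatch(long rows, long accCols)
{
  const int MAX_MMA_UNROLL = 7;          // as in the patch
  long row = 0;
  while (row + MAX_MMA_UNROLL * accCols <= rows)
    do_block<MAX_MMA_UNROLL>(row, accCols);
  switch ((rows - row) / accCols)        // leftover full accCols-wide blocks
  {
    case 6: do_block<6>(row, accCols); break;
    case 5: do_block<5>(row, accCols); break;
    case 4: do_block<4>(row, accCols); break;
    case 3: do_block<3>(row, accCols); break;
    case 2: do_block<2>(row, accCols); break;
    case 1: do_block<1>(row, accCols); break;
    default: break;                      // remaining_rows < accCols handled by gemm_extra_row
  }
}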
remaining_cols; - lhs_ptr += accCols; - } - } - - if(remaining_rows > 0 ) - { - const Scalar *rhs_ptr = rhs_base; - const Scalar *lhs_ptr = lhs_base + (row/accCols)*strideA*accCols; + gemm_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha); - lhs_ptr += remaining_rows*offsetA; - rhs_ptr += remaining_cols*offsetB; - for(Index k = 0; k < depth; k++) + if (remaining_rows > 0) { - for(Index arow = 0; arow < remaining_rows; arow++) - { - for(Index acol = 0; acol < remaining_cols; acol++ ) - { - res(row + arow, col + acol) += alpha*lhs_ptr[arow]*rhs_ptr[acol]; - } - } - rhs_ptr += remaining_cols; - lhs_ptr += remaining_rows; + gemm_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha); } + rhs_base++; } } } -template +template void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, - Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB, const int accRows, const int accCols) + Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { const int remaining_rows = rows % accCols; const int remaining_cols = cols % accRows; diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index df04b8e0f..495afac90 100755 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -2338,12 +2338,11 @@ template<> EIGEN_STRONG_INLINE void pbroadcast4(const double *a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) { - a1 = pload(a); - a0 = vec_splat_dbl<0>(a1); - a1 = vec_splat_dbl<1>(a1); - a3 = pload(a+2); - a2 = vec_splat_dbl<0>(a3); - a3 = vec_splat_dbl<1>(a3); + //This way is faster than vec_splat (at least for doubles in Power 9) + a0 = pset1(a[0]); + a1 = pset1(a[1]); + a2 = pset1(a[2]); + a3 = pset1(a[3]); } template<> EIGEN_DEVICE_FUNC inline Packet2d pgather(const double* from, Index stride)
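Finally, the PacketMath.h hunk replaces the pload plus vec_splat_dbl sequence in pbroadcast4<Packet2d> with four direct scalar splats, which the patch comment reports as faster for doubles on Power9. Roughly what the new code does, written against altivec.h directly (vec_splats is, as far as I can tell, what pset1<Packet2d> reduces to, which is worth confirming against PacketMath.h):

#include <altivec.h>

void broadcast4_d(const double* a,
                  __vector double& a0, __vector double& a1,
                  __vector double& a2, __vector double& a3)
{
  a0 = vec_splats(a[0]);   // {a[0], a[0]}
  a1 = vec_splats(a[1]);
  a2 = vec_splats(a[2]);
  a3 = vec_splats(a[3]);
}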