From 4e598ad259bf9561cd5882326e7cafc585d14f47 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Wed, 6 Sep 2023 20:03:45 +0000 Subject: [PATCH] New panel modes for GEMM MMA (real & complex). --- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 456 ++++++++++------ .../src/Core/arch/AltiVec/MatrixProductMMA.h | 510 ++++++++++++------ 2 files changed, 631 insertions(+), 335 deletions(-) diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index b232a8fb5..e9a930711 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -49,11 +49,6 @@ #include "MatrixProductMMA.h" #endif -/************************************************************************************************** - * TODO * - * - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could). * - * - Check the possibility of transposing as GETREAL and GETIMAG when needed. * - **************************************************************************************************/ // IWYU pragma: private #include "../../InternalHeaderCheck.h" @@ -120,6 +115,16 @@ const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7, 20, 21, 22, 23, 28, 29, 30, 31}; +const static Packet16uc p16uc_GETREAL32b = { 0, 1, 2, 3, + 16, 17, 18, 19, + 8, 9, 10, 11, + 24, 25, 26, 27}; + +const static Packet16uc p16uc_GETIMAG32b = { 4, 5, 6, 7, + 20, 21, 22, 23, + 12, 13, 14, 15, + 28, 29, 30, 31}; + /********************************************* * Single precision real and complex packing * * *******************************************/ @@ -440,6 +445,78 @@ EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock& block) // General template for lhs & rhs complex packing. template struct dhs_cpack { + template + EIGEN_ALWAYS_INLINE void dhs_cblock(PacketBlock& cblock, PacketBlock& block, Packet16uc permute) + { + if (transpose) { + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, permute); + block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, permute); + block.packet[2] = vec_perm(cblock.packet[4].v, cblock.packet[5].v, permute); + block.packet[3] = vec_perm(cblock.packet[6].v, cblock.packet[7].v, permute); + + Packet4f t0, t1, t2, t3; +#ifdef EIGEN_VECTORIZE_VSX + t0 = reinterpret_cast(vec_mergeh(reinterpret_cast(block.packet[0]), reinterpret_cast(block.packet[1]))); + t1 = reinterpret_cast(vec_mergel(reinterpret_cast(block.packet[0]), reinterpret_cast(block.packet[1]))); + t2 = reinterpret_cast(vec_mergeh(reinterpret_cast(block.packet[2]), reinterpret_cast(block.packet[3]))); + t3 = reinterpret_cast(vec_mergel(reinterpret_cast(block.packet[2]), reinterpret_cast(block.packet[3]))); +#else + t0 = reinterpret_cast(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_HI)); + t1 = reinterpret_cast(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_LO)); + t2 = reinterpret_cast(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_HI)); + t3 = reinterpret_cast(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_LO)); +#endif + + block.packet[0] = t0; + block.packet[1] = t1; + block.packet[2] = t2; + block.packet[3] = t3; + } else { + block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, permute); + block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, permute); + block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, permute); + block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, permute); + } + } + + EIGEN_ALWAYS_INLINE 
void dhs_ccopy(Scalar* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii, Index depth, const Index vectorSize) + { + PacketBlock blockr, blocki; + PacketBlock cblock; + + for(; i + vectorSize <= depth; i+=vectorSize) + { + if (UseLhs) { + bload(cblock, lhs2, 0, i); + } else { + bload(cblock, lhs2, i, 0); + } + + if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs))) + { + dhs_cblock(cblock, blockr, p16uc_GETREAL32b); + dhs_cblock(cblock, blocki, p16uc_GETIMAG32b); + } else { + dhs_cblock(cblock, blockr, p16uc_GETREAL32); + dhs_cblock(cblock, blocki, p16uc_GETIMAG32); + } + + if(Conjugate) + { + blocki.packet[0] = -blocki.packet[0]; + blocki.packet[1] = -blocki.packet[1]; + blocki.packet[2] = -blocki.packet[2]; + blocki.packet[3] = -blocki.packet[3]; + } + + storeBlock(blockAt + rir, blockr); + storeBlock(blockAt + rii, blocki); + + rir += 4*vectorSize; + rii += 4*vectorSize; + } + } + EIGEN_STRONG_INLINE void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { const Index vectorSize = quad_traits::vectorsize; @@ -455,47 +532,8 @@ struct dhs_cpack { rii = rir + vectorDelta; - for(; i + vectorSize <= depth; i+=vectorSize) - { - PacketBlock blockr, blocki; - PacketBlock cblock; + dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize); - if (UseLhs) { - bload(cblock, lhs2, 0, i); - } else { - bload(cblock, lhs2, i, 0); - } - - blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); - blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32); - blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32); - blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32); - - blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32); - blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32); - blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32); - blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32); - - if(Conjugate) - { - blocki.packet[0] = -blocki.packet[0]; - blocki.packet[1] = -blocki.packet[1]; - blocki.packet[2] = -blocki.packet[2]; - blocki.packet[3] = -blocki.packet[3]; - } - - if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs))) - { - ptranspose(blockr); - ptranspose(blocki); - } - - storeBlock(blockAt + rir, blockr); - storeBlock(blockAt + rii, blocki); - - rir += 4*vectorSize; - rii += 4*vectorSize; - } for(; i < depth; i++) { PacketBlock blockr, blocki; @@ -592,6 +630,36 @@ struct dhs_cpack { // General template for lhs & rhs packing. 
template struct dhs_pack{ + template + EIGEN_ALWAYS_INLINE void dhs_copy(Scalar* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth, const Index vectorSize) + { + PacketBlock block[n]; + + for(; i + n*vectorSize <= depth; i+=n*vectorSize) + { + for (Index k = 0; k < n; k++) { + if (UseLhs) { + bload(block[k], lhs2, 0, i + k*vectorSize); + } else { + bload(block[k], lhs2, i + k*vectorSize, 0); + } + } + + if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) + { + for (Index k = 0; k < n; k++) { + ptranspose(block[k]); + } + } + + for (Index k = 0; k < n; k++) { + storeBlock(blockA + ri + k*4*vectorSize, block[k]); + } + + ri += n*4*vectorSize; + } + } + EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { const Index vectorSize = quad_traits::vectorsize; @@ -604,24 +672,10 @@ struct dhs_pack{ if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize <= depth; i+=vectorSize) - { - PacketBlock block; + dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize); + dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize); + dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize); - if (UseLhs) { - bload(block, lhs2, 0, i); - } else { - bload(block, lhs2, i, 0); - } - if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) - { - ptranspose(block); - } - - storeBlock(blockA + ri, block); - - ri += 4*vectorSize; - } for(; i < depth; i++) { if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) @@ -691,6 +745,39 @@ struct dhs_pack{ template struct dhs_pack { + template + EIGEN_ALWAYS_INLINE void dhs_copy(double* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth, const Index vectorSize) + { + PacketBlock block[n]; + + for(; i + n*vectorSize <= depth; i+=n*vectorSize) + { + for (Index k = 0; k < n; k++) { + if(StorageOrder == RowMajor) + { + block[k].packet[0] = lhs2.template loadPacket(0, i + k*vectorSize); + block[k].packet[1] = lhs2.template loadPacket(1, i + k*vectorSize); + } else { + block[k].packet[0] = lhs2.template loadPacket(0, i + k*vectorSize + 0); + block[k].packet[1] = lhs2.template loadPacket(0, i + k*vectorSize + 1); + } + } + + if(StorageOrder == RowMajor) + { + for (Index k = 0; k < n; k++) { + ptranspose(block[k]); + } + } + + for (Index k = 0; k < n; k++) { + storeBlock(blockA + ri + k*2*vectorSize, block[k]); + } + + ri += n*2*vectorSize; + } + } + EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { const Index vectorSize = quad_traits::vectorsize; @@ -703,24 +790,10 @@ struct dhs_pack if(PanelMode) ri += vectorSize*offset; - for(; i + vectorSize <= depth; i+=vectorSize) - { - PacketBlock block; - if(StorageOrder == RowMajor) - { - block.packet[0] = lhs2.template loadPacket(0, i); - block.packet[1] = lhs2.template loadPacket(1, i); + dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize); + dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize); + dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize); - ptranspose(block); - } else { - block.packet[0] = lhs2.template loadPacket(0, i + 0); - block.packet[1] = lhs2.template loadPacket(0, i + 1); - } - - storeBlock(blockA + ri, block); - - ri += 2*vectorSize; - } for(; i < depth; i++) { if(StorageOrder == RowMajor) @@ -759,6 +832,53 @@ struct dhs_pack template struct dhs_pack { + template + EIGEN_ALWAYS_INLINE void dhs_copy(double* blockB, const DataMapper& rhs2, Index& i, Index& 
ri, Index depth, const Index vectorSize) + { + PacketBlock block1[n], block2[n]; + PacketBlock block3[n]; + + for(; i + n*vectorSize <= depth; i+=n*vectorSize) + { + for (Index k = 0; k < n; k++) { + if(StorageOrder == ColMajor) + { + block1[k].packet[0] = rhs2.template loadPacket(i + k*vectorSize, 0); + block1[k].packet[1] = rhs2.template loadPacket(i + k*vectorSize, 1); + block2[k].packet[0] = rhs2.template loadPacket(i + k*vectorSize, 2); + block2[k].packet[1] = rhs2.template loadPacket(i + k*vectorSize, 3); + } else { + block3[k].packet[0] = rhs2.template loadPacket(i + k*vectorSize + 0, 0); //[a1 a2] + block3[k].packet[1] = rhs2.template loadPacket(i + k*vectorSize + 0, 2); //[a3 a4] + block3[k].packet[2] = rhs2.template loadPacket(i + k*vectorSize + 1, 0); //[b1 b2] + block3[k].packet[3] = rhs2.template loadPacket(i + k*vectorSize + 1, 2); //[b3 b4] + } + } + + if(StorageOrder == ColMajor) + { + for (Index k = 0; k < n; k++) { + ptranspose(block1[k]); + ptranspose(block2[k]); + } + } + + for (Index k = 0; k < n; k++) { + if(StorageOrder == ColMajor) + { + pstore(blockB + ri + k*4*vectorSize , block1[k].packet[0]); + pstore(blockB + ri + k*4*vectorSize + 2, block2[k].packet[0]); + pstore(blockB + ri + k*4*vectorSize + 4, block1[k].packet[1]); + pstore(blockB + ri + k*4*vectorSize + 6, block2[k].packet[1]); + } else { + storeBlock(blockB + ri + k*4*vectorSize, block3[k]); + } + } + + ri += n*4*vectorSize; + } + } + EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { const Index vectorSize = quad_traits::vectorsize; @@ -771,35 +891,10 @@ struct dhs_pack if(PanelMode) ri += offset*(2*vectorSize); - for(; i + vectorSize <= depth; i+=vectorSize) - { - PacketBlock block; - if(StorageOrder == ColMajor) - { - PacketBlock block1, block2; - block1.packet[0] = rhs2.template loadPacket(i, 0); - block1.packet[1] = rhs2.template loadPacket(i, 1); - block2.packet[0] = rhs2.template loadPacket(i, 2); - block2.packet[1] = rhs2.template loadPacket(i, 3); + dhs_copy<4>(blockB, rhs2, i, ri, depth, vectorSize); + dhs_copy<2>(blockB, rhs2, i, ri, depth, vectorSize); + dhs_copy<1>(blockB, rhs2, i, ri, depth, vectorSize); - ptranspose(block1); - ptranspose(block2); - - pstore(blockB + ri , block1.packet[0]); - pstore(blockB + ri + 2, block2.packet[0]); - pstore(blockB + ri + 4, block1.packet[1]); - pstore(blockB + ri + 6, block2.packet[1]); - } else { - block.packet[0] = rhs2.template loadPacket(i + 0, 0); //[a1 a2] - block.packet[1] = rhs2.template loadPacket(i + 0, 2); //[a3 a4] - block.packet[2] = rhs2.template loadPacket(i + 1, 0); //[b1 b2] - block.packet[3] = rhs2.template loadPacket(i + 1, 2); //[b3 b4] - - storeBlock(blockB + ri, block); - } - - ri += 4*vectorSize; - } for(; i < depth; i++) { if(StorageOrder == ColMajor) @@ -1296,6 +1391,54 @@ struct dhs_pack template struct dhs_cpack { + EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii, Index depth, const Index vectorSize) + { + PacketBlock blockr, blocki; + PacketBlock cblock; + + for(; i + vectorSize <= depth; i+=vectorSize) + { + if(StorageOrder == ColMajor) + { + cblock.packet[0] = lhs2.template loadPacket(0, i + 0); //[a1 a1i] + cblock.packet[1] = lhs2.template loadPacket(0, i + 1); //[b1 b1i] + + cblock.packet[2] = lhs2.template loadPacket(1, i + 0); //[a2 a2i] + cblock.packet[3] = lhs2.template loadPacket(1, i + 1); //[b2 b2i] + + blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 
a2] + blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2] + + blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v); + blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v); + } else { + cblock.packet[0] = lhs2.template loadPacket(0, i); //[a1 a1i] + cblock.packet[1] = lhs2.template loadPacket(1, i); //[a2 a2i] + + cblock.packet[2] = lhs2.template loadPacket(0, i + 1); //[b1 b1i] + cblock.packet[3] = lhs2.template loadPacket(1, i + 1); //[b2 b2i + + blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2] + blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2] + + blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v); + blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v); + } + + if(Conjugate) + { + blocki.packet[0] = -blocki.packet[0]; + blocki.packet[1] = -blocki.packet[1]; + } + + storeBlock(blockAt + rir, blockr); + storeBlock(blockAt + rii, blocki); + + rir += 2*vectorSize; + rii += 2*vectorSize; + } + } + EIGEN_STRONG_INLINE void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) { const Index vectorSize = quad_traits::vectorsize; @@ -1311,50 +1454,8 @@ struct dhs_cpack blockr, blocki; - PacketBlock cblock; + dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize); - if(StorageOrder == ColMajor) - { - cblock.packet[0] = lhs2.template loadPacket(0, i + 0); //[a1 a1i] - cblock.packet[1] = lhs2.template loadPacket(0, i + 1); //[b1 b1i] - - cblock.packet[2] = lhs2.template loadPacket(1, i + 0); //[a2 a2i] - cblock.packet[3] = lhs2.template loadPacket(1, i + 1); //[b2 b2i] - - blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2] - blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2] - - blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v); - blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v); - } else { - cblock.packet[0] = lhs2.template loadPacket(0, i); //[a1 a1i] - cblock.packet[1] = lhs2.template loadPacket(1, i); //[a2 a2i] - - cblock.packet[2] = lhs2.template loadPacket(0, i + 1); //[b1 b1i] - cblock.packet[3] = lhs2.template loadPacket(1, i + 1); //[b2 b2i - - blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2] - blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2] - - blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v); - blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v); - } - - if(Conjugate) - { - blocki.packet[0] = -blocki.packet[0]; - blocki.packet[1] = -blocki.packet[1]; - } - - storeBlock(blockAt + rir, blockr); - storeBlock(blockAt + rii, blocki); - - rir += 2*vectorSize; - rii += 2*vectorSize; - } for(; i < depth; i++) { PacketBlock blockr, blocki; @@ -1410,6 +1511,35 @@ struct dhs_cpack struct dhs_cpack { + EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockBt, const DataMapper& rhs2, Index& i, Index& rir, Index& rii, Index depth, const Index vectorSize) + { + for(; i < depth; i++) + { + PacketBlock cblock; + PacketBlock blockr, blocki; + + bload(cblock, rhs2, i, 0); + + blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); + blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); + + blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v); + blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v); + + if(Conjugate) + { + 
blocki.packet[0] = -blocki.packet[0]; + blocki.packet[1] = -blocki.packet[1]; + } + + storeBlock(blockBt + rir, blockr); + storeBlock(blockBt + rii, blocki); + + rir += 2*vectorSize; + rii += 2*vectorSize; + } + } + EIGEN_STRONG_INLINE void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) { const Index vectorSize = quad_traits::vectorsize; @@ -1425,31 +1555,7 @@ struct dhs_cpack cblock; - PacketBlock blockr, blocki; - - bload(cblock, rhs2, i, 0); - - blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); - blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); - - blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v); - blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v); - - if(Conjugate) - { - blocki.packet[0] = -blocki.packet[0]; - blocki.packet[1] = -blocki.packet[1]; - } - - storeBlock(blockBt + rir, blockr); - storeBlock(blockBt + rii, blocki); - - rir += 2*vectorSize; - rii += 2*vectorSize; - } + dhs_ccopy(blockBt, rhs2, i, rir, rii, depth, vectorSize); rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta); } diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h index e83cba5fd..72e8c31a2 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -42,19 +42,13 @@ EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc) __builtin_mma_xxsetaccz(acc); } -#ifdef USE_PARTIAL_PACKETS template EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Index elements, __vector_quad* acc) -#else -template -EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Packet& pMask, __vector_quad* acc) -#endif { PacketBlock result; __builtin_mma_disassemble_acc(&result.packet, acc); PacketBlock tRes; -#ifdef USE_PARTIAL_PACKETS if (full) { EIGEN_UNUSED_VARIABLE(elements); bload(tRes, data, i, 0); @@ -65,11 +59,6 @@ EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const bscale(tRes, result, alpha); bstore_partial(tRes, data, i, elements); } -#else - bload(tRes, data, i, 0); - bscale(tRes, result, alpha, pMask); - bstore(tRes, data, i); -#endif } template @@ -166,78 +155,118 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) ploadRhsMMA(lhs, lhsV); } -#if (EIGEN_COMP_LLVM || (__GNUC__ >= 11)) +#define GEMM_MULTIPLE_COLS + +// Disable in GCC until unnecessary register moves are fixed +//#if (EIGEN_COMP_LLVM || (__GNUC__ >= 11)) +#if EIGEN_COMP_LLVM #define VECTOR_PAIR_LOADS_LHS #endif // PEEL_MMA loop factor. 
+#ifdef GEMM_MULTIPLE_COLS +#define PEEL_MMA 8 +#else +// Register spillage with GCC12+ +#if EIGEN_COMP_LLVM || (__GNUC__ < 12) || defined(VECTOR_PAIR_LOADS_LHS) #define PEEL_MMA 7 +#else +#define PEEL_MMA 6 +#endif +#endif #define MICRO_MMA_UNROLL(func) \ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) #define MICRO_MMA_WORK(func, type, peel) \ - func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \ - func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) + if (accItr == 1) { \ + func(0,type,peel,0,0) func(1,type,peel,1,0) func(2,type,peel,2,0) func(3,type,peel,3,0) \ + func(4,type,peel,4,0) func(5,type,peel,5,0) func(6,type,peel,6,0) func(7,type,peel,7,0) \ + } else if (accItr == 2) { \ + func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,1,0) func(3,type,peel,1,1) \ + func(4,type,peel,2,0) func(5,type,peel,2,1) func(6,type,peel,3,0) func(7,type,peel,3,1) \ + } else { \ + func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,0,2) func(3,type,peel,0,3) \ + func(4,type,peel,1,0) func(5,type,peel,1,1) func(6,type,peel,1,2) func(7,type,peel,1,3) \ + } -#define MICRO_MMA_WORK_ONE(iter, type, peel) \ - if (unroll_factor > iter) { \ - pgerMMA(&accZero##iter, rhsV[peel], lhsV##iter); \ +#define MICRO_MMA_WORK_ONE(iter, type, peel, left, right) \ + if (unroll_factor > left) { \ + pgerMMA(&accZero##iter, rhsV##right[peel], lhsV##left); \ } #ifdef VECTOR_PAIR_LOADS_LHS -#define MICRO_MMA_WORK_TWO(iter, type, peel) \ - if (unroll_factor > iter) { \ - pgerMMA(&accZero##iter, rhsV[peel], lhsV2##iter.packet[peel & 1]); \ +#define MICRO_MMA_WORK_TWO(iter, type, peel, left, right) \ + if (unroll_factor > left) { \ + pgerMMA(&accZero##iter, rhsV##right[peel], lhsV2##left.packet[peel & 1]); \ } -#define MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) \ - if (unroll_factor > iter) { \ - if (MICRO_NORMAL(iter)) { \ - ploadLhsMMA(reinterpret_cast(lhs_ptr##iter), plhsV##iter); \ - __builtin_vsx_disassemble_pair(reinterpret_cast(&lhsV2##iter.packet), &plhsV##iter); \ - lhs_ptr##iter += accCols*2; \ +#define MICRO_MMA_LOAD1_TWO(lhs_ptr, left) \ + if (unroll_factor > left) { \ + if (MICRO_NORMAL(left)) { \ + ploadLhsMMA(reinterpret_cast(lhs_ptr##left), plhsV##left); \ + __builtin_vsx_disassemble_pair(reinterpret_cast(&lhsV2##left.packet), &plhsV##left); \ + lhs_ptr##left += accCols*2; \ } else { \ - lhsV2##iter.packet[0] = ploadLhs(lhs_ptr##iter); \ - lhsV2##iter.packet[1] = ploadLhs(lhs_ptr##iter + accCols2); \ - lhs_ptr##iter += accCols2*2; \ - EIGEN_UNUSED_VARIABLE(plhsV##iter) \ + lhsV2##left.packet[0] = ploadLhs(lhs_ptr##left); \ + lhsV2##left.packet[1] = ploadLhs(lhs_ptr##left + accCols2); \ + lhs_ptr##left += accCols2*2; \ + EIGEN_UNUSED_VARIABLE(plhsV##left); \ } \ } else { \ - EIGEN_UNUSED_VARIABLE(lhsV2##iter); \ - EIGEN_UNUSED_VARIABLE(plhsV##iter) \ + EIGEN_UNUSED_VARIABLE(lhsV2##left); \ + EIGEN_UNUSED_VARIABLE(plhsV##left); \ } -#define MICRO_MMA_LOAD_TWO(iter) MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) +#define MICRO_MMA_LOAD_TWO(left) MICRO_MMA_LOAD1_TWO(lhs_ptr, left) #endif +#define MICRO_MMA_UNROLL_ITER(func, val) \ + func(val,0) \ + if (accItr > 1) { \ + func(val,1) \ + if (accItr > 2) { \ + func(val,2) \ + func(val,3) \ + } \ + } + +#define MICRO_MMA_LOAD_ONE_RHS1(peel, right) \ + ploadRhsMMA(rhs_ptr##right + (accRows * peel), rhsV##right[peel]); + +#define MICRO_MMA_LOAD_ONE_RHS(peel) \ + MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_ONE_RHS1, peel) + #define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \ if (PEEL_MMA > peel) { \ Packet lhsV0, lhsV1, 
lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ - ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV[peel]); \ + MICRO_MMA_LOAD_ONE_RHS(peel) \ MICRO_MMA_UNROLL(funcl) \ MICRO_MMA_WORK(funcw, type, peel) \ } #ifndef VECTOR_PAIR_LOADS_LHS #define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \ - type rhsV[8]; \ + type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \ MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,1) \ MICRO_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,3) \ MICRO_MMA_TYPE_PEEL(funcw,funcl,type,4) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,5) \ MICRO_MMA_TYPE_PEEL(funcw,funcl,type,6) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,7) #else +#define MICRO_MMA_LOAD_TWO_RHS(peel1, right) \ + ploadRhsMMA(reinterpret_cast(rhs_ptr##right + (accRows * peel1)), prhsV##peel1); \ + __builtin_vsx_disassemble_pair(reinterpret_cast(&rhsV##right[peel1]), &prhsV##peel1); + #define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \ if (PEEL_MMA > peel2) { \ PacketBlock lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \ __vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \ if (sizeof(type) == 16) { \ - ploadRhsMMA(reinterpret_cast(rhs_ptr + (accRows * peel1)), prhsV##peel1); \ - __builtin_vsx_disassemble_pair(reinterpret_cast(&rhsV[peel1]), &prhsV##peel1); \ + MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_TWO_RHS, peel1) \ } else { \ EIGEN_UNUSED_VARIABLE(prhsV##peel1); \ - ploadRhsMMA(rhs_ptr + (accRows * peel1), rhsV[peel1]); \ - ploadRhsMMA(rhs_ptr + (accRows * peel2), rhsV[peel2]); \ + MICRO_MMA_LOAD_ONE_RHS(peel1) \ + MICRO_MMA_LOAD_ONE_RHS(peel2) \ } \ MICRO_MMA_UNROLL(funcl2) \ MICRO_MMA_WORK(funcw2, type, peel1) \ @@ -248,7 +277,7 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) } #define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \ - type rhsV[8]; \ + type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 
8 : 1]; \ __vector_pair prhsV0, prhsV2, prhsV4, prhsV6; \ MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \ MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3) \ @@ -257,19 +286,25 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) #endif #define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \ - type rhsV[1]; \ + type rhsV0[1], rhsV1[1], rhsV2[1], rhsV3[1]; \ MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0) +#define MICRO_MMA_UPDATE_RHS1(size, right) \ + rhs_ptr##right += (accRows * size); + +#define MICRO_MMA_UPDATE_RHS(size) \ + MICRO_MMA_UNROLL_ITER(MICRO_MMA_UPDATE_RHS1, size) + #define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \ MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \ - rhs_ptr += (accRows * size); + MICRO_MMA_UPDATE_RHS(size) #ifndef VECTOR_PAIR_LOADS_LHS #define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_PEEL, PEEL_MMA) #else #define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size) \ MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \ - rhs_ptr += (accRows * size); + MICRO_MMA_UPDATE_RHS(size) #define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA) #endif @@ -277,7 +312,7 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) #define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1) #define MICRO_MMA_DST_PTR_ONE(iter) \ - if (unroll_factor > iter) { \ + if (unroll_factor * accItr > iter) { \ bsetzeroMMA(&accZero##iter); \ } else { \ EIGEN_UNUSED_VARIABLE(accZero##iter); \ @@ -289,45 +324,69 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) #define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE) -#ifdef USE_PARTIAL_PACKETS -#define MICRO_MMA_STORE_ONE(iter) \ - if (unroll_factor > iter) { \ - storeAccumulator(row + iter*accCols, res, pAlpha, accCols2, &accZero##iter); \ +#define MICRO_MMA_STORE_ONE(iter, left, right) \ + if (unroll_factor > left) { \ + storeAccumulator(row + left*accCols, res##right, pAlpha, accCols2, &accZero##iter); \ } -#else -#define MICRO_MMA_STORE_ONE(iter) \ - if (unroll_factor > iter) { \ - storeAccumulator(row + iter*accCols, res, pAlpha, pMask, &accZero##iter); \ + +#define MICRO_MMA_ITER_UNROLL(func) \ + if (accItr == 1) { \ + func(0,0,0) func(1,1,0) func(2,2,0) func(3,3,0) \ + func(4,4,0) func(5,5,0) func(6,6,0) func(7,7,0) \ + } else if (accItr == 2) { \ + func(0,0,0) func(1,0,1) func(2,1,0) func(3,1,1) \ + func(4,2,0) func(5,2,1) func(6,3,0) func(7,3,1) \ + } else { \ + func(0,0,0) func(1,0,1) func(2,0,2) func(3,0,3) \ + func(4,1,0) func(5,1,1) func(6,1,2) func(7,1,3) \ } -#endif -#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE) +#define MICRO_MMA_STORE MICRO_MMA_ITER_UNROLL(MICRO_MMA_STORE_ONE) -#ifdef USE_PARTIAL_PACKETS -template -#else -template -#endif +#define MICRO_MMA_EXTRA_ROWS(right) \ + gemm_extra_row(res3##right, blockA, rhs_base + right*accRows*strideB, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask); + +#define MICRO_MMA_EXTRA_ROWS1(val, right) \ + MICRO_MMA_EXTRA_ROWS(right); + +template EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration( - const DataMapper& res, + const DataMapper& res0, + const DataMapper& res1, + const DataMapper& res2, + const DataMapper& res3, const Scalar* lhs_base, const Scalar* rhs_base, Index depth, Index strideA, + Index strideB, Index offsetA, Index& row, const Packet& pAlpha, -#ifdef USE_PARTIAL_PACKETS Index accCols2 -#else - 
const Packet& pMask -#endif ) { - const Scalar* rhs_ptr = rhs_base; + const Scalar* rhs_ptr0 = rhs_base, * rhs_ptr1 = NULL, * rhs_ptr2 = NULL, * rhs_ptr3 = NULL; const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL; __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + if (accItr > 1) { + rhs_ptr1 = rhs_base + (accRows * strideB); + } else { + EIGEN_UNUSED_VARIABLE(strideB); + EIGEN_UNUSED_VARIABLE(rhs_ptr1); + EIGEN_UNUSED_VARIABLE(res1); + } + if (accItr > 2) { + rhs_ptr2 = rhs_base + (2 * accRows * strideB); + rhs_ptr3 = rhs_base + (3 * accRows * strideB); + } else { + EIGEN_UNUSED_VARIABLE(rhs_ptr2); + EIGEN_UNUSED_VARIABLE(rhs_ptr3); + EIGEN_UNUSED_VARIABLE(res2); + EIGEN_UNUSED_VARIABLE(res3); + } + MICRO_MMA_SRC_PTR MICRO_MMA_DST_PTR @@ -347,17 +406,16 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration( MICRO_UPDATE } -#ifdef USE_PARTIAL_PACKETS #define MICRO_MMA_UNROLL_ITER2(N, M) \ - gemm_unrolled_MMA_iteration(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, M ? remaining_rows : accCols); \ + gemm_unrolled_MMA_iteration(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, strideB, offsetA, row, pAlpha, M ? remaining_rows : accCols); \ if (M) return; -#else -#define MICRO_MMA_UNROLL_ITER2(N, M) \ - gemm_unrolled_MMA_iteration(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, pMask); \ - if (M) return; -#endif -template +#define MICRO_MMA_ROWS(n) \ + while(row + n*accCols <= rows) { \ + MICRO_MMA_UNROLL_ITER2(n, 0); \ + } + +template EIGEN_ALWAYS_INLINE void gemmMMA_cols( const DataMapper& res, const Scalar* blockA, @@ -373,45 +431,71 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols( const Packet& pAlpha, const Packet& pMask) { - const DataMapper res3 = res.getSubMapper(0, col); + const DataMapper res30 = res.getSubMapper(0, col); + const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows*1) : res30; + const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows*2) : res30; + const DataMapper res33 = (accItr > 2) ? 
res30.getSubMapper(0, accRows*3) : res30; const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB; const Scalar* lhs_base = blockA + accCols*offsetA; Index row = 0; #define MAX_MMA_UNROLL 7 - while(row + MAX_MMA_UNROLL*accCols <= rows) { - MICRO_MMA_UNROLL_ITER2(MAX_MMA_UNROLL, 0); + +#if MAX_MMA_UNROLL < 2 + if (1) { +#elif MAX_MMA_UNROLL < 4 + if (accItr <= 2) { +#else + if (accItr == 1) { +#endif + MICRO_MMA_ROWS(MAX_MMA_UNROLL); + } else if (accItr == 2) { + MICRO_MMA_ROWS(4); + } else { + MICRO_MMA_ROWS(2); } switch( (rows-row)/accCols ) { #if MAX_MMA_UNROLL > 7 case 7: - MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7) + if (accItr == 1) { + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7) + } break; #endif #if MAX_MMA_UNROLL > 6 case 6: - MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6) + if (accItr == 1) { + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6) + } break; #endif #if MAX_MMA_UNROLL > 5 case 5: - MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5) + if (accItr == 1) { + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5) + } break; #endif #if MAX_MMA_UNROLL > 4 case 4: - MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4) + if (accItr == 1) { + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4) + } break; #endif #if MAX_MMA_UNROLL > 3 case 3: - MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3) + if (accItr <= 2) { + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3) + } break; #endif #if MAX_MMA_UNROLL > 2 case 2: - MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2) + if (accItr <= 2) { + MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2) + } break; #endif #if MAX_MMA_UNROLL > 1 @@ -426,10 +510,16 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols( if(remaining_rows > 0) { - gemm_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask); + MICRO_MMA_UNROLL_ITER(MICRO_MMA_EXTRA_ROWS1, 0) } } +#define MICRO_MMA_COLS(n) \ + for(; col + n*accRows <= cols; col += n*accRows) \ + { \ + gemmMMA_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); \ + } + template void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { @@ -444,10 +534,11 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2; Index col = 0; - for(; col + accRows <= cols; col += accRows) - { - gemmMMA_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); - } +#ifdef GEMM_MULTIPLE_COLS + MICRO_MMA_COLS(4); + MICRO_MMA_COLS(2); +#endif + MICRO_MMA_COLS(1); if (col != cols) { @@ -459,62 +550,88 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #define advanceCols ((RhsIsReal) ? 1 : 2) // PEEL_COMPLEX_MMA loop factor. 
+#ifdef GEMM_MULTIPLE_COLS +#define PEEL_COMPLEX_MMA 4 +#else #define PEEL_COMPLEX_MMA 3 +#endif #define MICRO_COMPLEX_MMA_UNROLL(func) \ func(0) func(1) func(2) func(3) #define MICRO_COMPLEX_MMA_WORK(func, type, peel) \ - func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) + if (accItr == 1) { \ + func(0,type,peel,0,0) func(1,type,peel,1,0) func(2,type,peel,2,0) func(3,type,peel,3,0) \ + } else if (accItr == 2) { \ + func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,1,0) func(3,type,peel,1,1) \ + } else { \ + func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,0,2) func(3,type,peel,0,3) \ + } -#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \ - if (unroll_factor > iter) { \ - pgercMMA(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV[peel], rhsVi[peel]); \ +#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel, left, right) \ + if (unroll_factor > left) { \ + pgercMMA(&accReal##iter, &accImag##iter, lhsV##left, lhsVi##left, rhsV##right[peel], rhsVi##right[peel]); \ } #ifdef VECTOR_PAIR_LOADS_LHS -#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel) \ - if (unroll_factor > iter) { \ - pgercMMA(&accReal##iter, &accImag##iter, lhsV2##iter.packet[peel & 1], lhsVi2##iter.packet[peel & 1], rhsV[peel], rhsVi[peel]); \ +#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel, left, right) \ + if (unroll_factor > left) { \ + pgercMMA(&accReal##iter, &accImag##iter, lhsV2##left.packet[peel & 1], lhsVi2##left.packet[peel & 1], rhsV##right[peel], rhsVi##right[peel]); \ } -#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) \ - if (!LhsIsReal && (unroll_factor > iter)) { \ - if (MICRO_NORMAL(iter)) { \ - ploadLhsMMA(reinterpret_cast(lhs_ptr_real##iter + imag_delta), plhsVi##iter); \ - __builtin_vsx_disassemble_pair(reinterpret_cast(&lhsVi2##iter.packet), &plhsVi##iter); \ +#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left) \ + if (!LhsIsReal && (unroll_factor > left)) { \ + if (MICRO_NORMAL(left)) { \ + ploadLhsMMA(reinterpret_cast(lhs_ptr_real##left + imag_delta), plhsVi##left); \ + __builtin_vsx_disassemble_pair(reinterpret_cast(&lhsVi2##left.packet), &plhsVi##left); \ } else { \ - lhsVi2##iter.packet[0] = ploadLhs(lhs_ptr_real##iter + imag_delta2); \ - lhsVi2##iter.packet[1] = ploadLhs(lhs_ptr_real##iter + imag_delta2 + accCols2); \ - EIGEN_UNUSED_VARIABLE(plhsVi##iter) \ + lhsVi2##left.packet[0] = ploadLhs(lhs_ptr_real##left + imag_delta2); \ + lhsVi2##left.packet[1] = ploadLhs(lhs_ptr_real##left + imag_delta2 + accCols2); \ + EIGEN_UNUSED_VARIABLE(plhsVi##left); \ } \ } else { \ - EIGEN_UNUSED_VARIABLE(lhsVi2##iter); \ - EIGEN_UNUSED_VARIABLE(plhsVi##iter) \ + EIGEN_UNUSED_VARIABLE(lhsVi2##left); \ + EIGEN_UNUSED_VARIABLE(plhsVi##left); \ } \ - MICRO_MMA_LOAD1_TWO(lhs_ptr_real, iter) + MICRO_MMA_LOAD1_TWO(lhs_ptr_real, left) -#define MICRO_COMPLEX_MMA_LOAD_TWO(iter) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) +#define MICRO_COMPLEX_MMA_LOAD_TWO(left) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left) #endif +#define MICRO_COMPLEX_MMA_LOAD_RHS1(peel, right) \ + ploadRhsMMA(rhs_ptr_real##right + (accRows * peel), rhsV##right[peel]); \ + if (!RhsIsReal) { \ + ploadRhsMMA(rhs_ptr_imag##right + (accRows * peel), rhsVi##right[peel]); \ + } + +#define MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) \ + MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_RHS1, peel) + #define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \ if (PEEL_COMPLEX_MMA > peel) { \ Packet lhsV0, lhsV1, lhsV2, lhsV3; \ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \ - ploadRhsMMA(rhs_ptr_real + (accRows * 
peel), rhsV[peel]); \ - if(!RhsIsReal) { \ - ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi[peel]); \ - } \ + MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) \ MICRO_COMPLEX_MMA_UNROLL(funcl) \ MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \ } #ifndef VECTOR_PAIR_LOADS_LHS #define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \ - type rhsV[4], rhsVi[4]; \ + type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1]; \ MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,1) \ MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,3) #else +#define MICRO_COMPLEX_MMA_LOAD_TWO_RHS(peel1, right) \ + ploadRhsMMA(reinterpret_cast(rhs_ptr_real##right + (accRows * peel1)), prhsV##peel1); \ + __builtin_vsx_disassemble_pair(reinterpret_cast(&rhsV##right[peel1]), &prhsV##peel1); \ + if(!RhsIsReal) { \ + ploadRhsMMA(reinterpret_cast(rhs_ptr_imag##right + (accRows * peel1)), prhsVi##peel1); \ + __builtin_vsx_disassemble_pair(reinterpret_cast(&rhsVi##right[peel1]), &prhsVi##peel1); \ + } else { \ + EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \ + } + #define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \ if (PEEL_COMPLEX_MMA > peel2) { \ PacketBlock lhsV20, lhsV21, lhsV22, lhsV23; \ @@ -522,23 +639,12 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, __vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \ __vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \ if (sizeof(type) == 16) { \ - ploadRhsMMA(reinterpret_cast(rhs_ptr_real + (accRows * peel1)), prhsV##peel1); \ - __builtin_vsx_disassemble_pair(reinterpret_cast(&rhsV[peel1]), &prhsV##peel1); \ - if(!RhsIsReal) { \ - ploadRhsMMA(reinterpret_cast(rhs_ptr_imag + (accRows * peel1)), prhsVi##peel1); \ - __builtin_vsx_disassemble_pair(reinterpret_cast(&rhsVi[peel1]), &prhsVi##peel1); \ - } else { \ - EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \ - } \ + MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_TWO_RHS, peel1) \ } else { \ EIGEN_UNUSED_VARIABLE(prhsV##peel1); \ EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \ - ploadRhsMMA(rhs_ptr_real + (accRows * peel1), rhsV[peel1]); \ - ploadRhsMMA(rhs_ptr_real + (accRows * peel2), rhsV[peel2]); \ - if(!RhsIsReal) { \ - ploadRhsMMA(rhs_ptr_imag + (accRows * peel1), rhsVi[peel1]); \ - ploadRhsMMA(rhs_ptr_imag + (accRows * peel2), rhsVi[peel2]); \ - } \ + MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel1); \ + MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel2); \ } \ MICRO_COMPLEX_MMA_UNROLL(funcl2) \ MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \ @@ -550,7 +656,7 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, } #define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \ - type rhsV[4], rhsVi[4]; \ + type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 
4 : 1]; \ __vector_pair prhsV0, prhsV2; \ __vector_pair prhsVi0, prhsVi2; \ MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \ @@ -558,21 +664,26 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #endif #define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \ - type rhsV[1], rhsVi[1]; \ + type rhsV0[1], rhsVi0[1], rhsV1[1], rhsVi1[1], rhsV2[1], rhsVi2[1], rhsV3[1], rhsVi3[1]; \ MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0) +#define MICRO_COMPLEX_MMA_UPDATE_RHS1(size, right) \ + rhs_ptr_real##right += (accRows * size); \ + if(!RhsIsReal) rhs_ptr_imag##right += (accRows * size); + +#define MICRO_COMPLEX_MMA_UPDATE_RHS(size) \ + MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_UPDATE_RHS1, size) + #define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \ MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \ - rhs_ptr_real += (accRows * size); \ - if(!RhsIsReal) rhs_ptr_imag += (accRows * size); + MICRO_COMPLEX_MMA_UPDATE_RHS(size); #ifndef VECTOR_PAIR_LOADS_LHS #define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL, PEEL_COMPLEX_MMA) #else #define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size) \ MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO, MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket) \ - rhs_ptr_real += (accRows * size); \ - if(!RhsIsReal) rhs_ptr_imag += (accRows * size); + MICRO_COMPLEX_MMA_UPDATE_RHS(size); #define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA) #endif @@ -580,7 +691,7 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1) #define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \ - if (unroll_factor > iter) { \ + if (unroll_factor * accItr > iter) { \ bsetzeroMMA(&accReal##iter); \ bsetzeroMMA(&accImag##iter); \ } else { \ @@ -594,16 +705,34 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE) -#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \ - if (unroll_factor > iter) { \ - storeComplexAccumulator(row + iter*accCols, res, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \ +#define MICRO_COMPLEX_MMA_STORE_ONE(iter, left, right) \ + if (unroll_factor > left) { \ + storeComplexAccumulator(row + left*accCols, res##right, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \ } -#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE) +#define MICRO_COMPLEX_MMA_ITER_UNROLL(func) \ + if (accItr == 1) { \ + func(0,0,0) func(1,1,0) func(2,2,0) func(3,3,0) \ + } else if (accItr == 2) { \ + func(0,0,0) func(1,0,1) func(2,1,0) func(3,1,1) \ + } else { \ + func(0,0,0) func(1,0,1) func(2,0,2) func(3,0,3) \ + } -template +#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_ITER_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE) + +#define MICRO_COMPLEX_MMA_EXTRA_ROWS(right) \ + gemm_complex_extra_row(res3##right, blockA, rhs_base + right*accRows*(RhsIsReal ? 
1 : 2)*strideB, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); + +#define MICRO_COMPLEX_MMA_EXTRA_ROWS1(val, right) \ + MICRO_COMPLEX_MMA_EXTRA_ROWS(right); + +template EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration( - const DataMapper& res, + const DataMapper& res0, + const DataMapper& res1, + const DataMapper& res2, + const DataMapper& res3, const Scalar* lhs_base, const Scalar* rhs_base, Index depth, @@ -615,14 +744,48 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration( const Packet& pAlphaImag, const Packet& pMask) { - const Scalar* rhs_ptr_real = rhs_base; - const Scalar* rhs_ptr_imag = NULL; + const Scalar* rhs_ptr_real0 = rhs_base, * rhs_ptr_real1 = NULL, * rhs_ptr_real2 = NULL, * rhs_ptr_real3 = NULL; + const Scalar* rhs_ptr_imag0 = NULL, * rhs_ptr_imag1 = NULL, * rhs_ptr_imag2 = NULL, * rhs_ptr_imag3 = NULL; const Index imag_delta = accCols*strideA; const Index imag_delta2 = accCols2*strideA; + if(!RhsIsReal) { - rhs_ptr_imag = rhs_base + accRows*strideB; + rhs_ptr_imag0 = rhs_base + accRows*strideB; } else { - EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag0); + } + if (accItr > 1) { + if(!RhsIsReal) { + rhs_ptr_real1 = rhs_base + (2*accRows*strideB); + rhs_ptr_imag1 = rhs_base + (3*accRows*strideB); + } else { + rhs_ptr_real1 = rhs_base + accRows*strideB; + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1); + } + } else { + EIGEN_UNUSED_VARIABLE(rhs_ptr_real1); + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1); + EIGEN_UNUSED_VARIABLE(res1); + } + if (accItr > 2) { + if(!RhsIsReal) { + rhs_ptr_real2 = rhs_base + (4*accRows*strideB); + rhs_ptr_imag2 = rhs_base + (5*accRows*strideB); + rhs_ptr_real3 = rhs_base + (6*accRows*strideB); + rhs_ptr_imag3 = rhs_base + (7*accRows*strideB); + } else { + rhs_ptr_real2 = rhs_base + (2*accRows*strideB); + rhs_ptr_real3 = rhs_base + (3*accRows*strideB); + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2); + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3); + } + } else { + EIGEN_UNUSED_VARIABLE(rhs_ptr_real2); + EIGEN_UNUSED_VARIABLE(rhs_ptr_real3); + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2); + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3); + EIGEN_UNUSED_VARIABLE(res2); + EIGEN_UNUSED_VARIABLE(res3); } const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL; const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL; @@ -651,10 +814,15 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration( } #define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \ - gemm_complex_unrolled_MMA_iteration(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \ + gemm_complex_unrolled_MMA_iteration(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \ if (M) return; -template +#define MICRO_COMPLEX_MMA_ROWS(n) \ + while(row + n*accCols <= rows) { \ + MICRO_COMPLEX_MMA_UNROLL_ITER2(n, 0); \ + } + +template EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols( const DataMapper& res, const Scalar* blockA, @@ -671,35 +839,50 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols( const Packet& pAlphaImag, const Packet& pMask) { - const DataMapper res3 = res.getSubMapper(0, col); + const DataMapper res30 = res.getSubMapper(0, col); + const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows*1) : res30; + const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows*2) : res30; + const DataMapper res33 = (accItr > 2) ? 
res30.getSubMapper(0, accRows*3) : res30; const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; const Scalar* lhs_base = blockA + accCols*offsetA; Index row = 0; #define MAX_COMPLEX_MMA_UNROLL 4 - while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) { - MICRO_COMPLEX_MMA_UNROLL_ITER2(MAX_COMPLEX_MMA_UNROLL, 0); + +#if MAX_COMPLEX_MMA_UNROLL < 2 + if (1) { +#elif MAX_COMPLEX_MMA_UNROLL < 4 + if (accItr <= 2) { +#else + if (accItr == 1) { +#endif + MICRO_COMPLEX_MMA_ROWS(MAX_COMPLEX_MMA_UNROLL); + } else if (accItr == 2) { + MICRO_COMPLEX_MMA_ROWS(2); + } else { + MICRO_COMPLEX_MMA_ROWS(1); } switch( (rows-row)/accCols ) { -#if MAX_COMPLEX_MMA_UNROLL > 4 - case 4: - MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 4) - break; -#endif #if MAX_COMPLEX_MMA_UNROLL > 3 case 3: - MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3) + if (accItr == 1) { + MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3) + } break; #endif #if MAX_COMPLEX_MMA_UNROLL > 2 case 2: - MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2) + if (accItr == 1) { + MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2) + } break; #endif #if MAX_COMPLEX_MMA_UNROLL > 1 case 1: - MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1) + if (accItr <= 2) { + MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1) + } break; #endif default: @@ -709,10 +892,16 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols( if(remaining_rows > 0) { - gemm_complex_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); + MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_EXTRA_ROWS1, 0) } } +#define MICRO_COMPLEX_MMA_COLS(n) \ + for(; col + n*accRows <= cols; col += n*accRows) \ + { \ + gemmMMA_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); \ + } + template void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { @@ -731,10 +920,11 @@ void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsS typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2; Index col = 0; - for(; col + accRows <= cols; col += accRows) - { - gemmMMA_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); - } +#ifdef GEMM_MULTIPLE_COLS + MICRO_COMPLEX_MMA_COLS(4); + MICRO_COMPLEX_MMA_COLS(2); +#endif + MICRO_COMPLEX_MMA_COLS(1); if (col != cols) {
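
For reference, the column-panel scheme this patch introduces (GEMM_MULTIPLE_COLS, the accItr template parameter, and the MICRO_MMA_COLS(4)/(2)/(1) dispatch) amounts to processing one, two, or four accRows-wide panels of the result per unrolled kernel call, while shrinking the row unroll factor so that unroll_factor * accItr stays within the eight MMA accumulators (accZero0..accZero7). The standalone C++ sketch below only illustrates that loop structure; accRows, accCols, fake_kernel and do_cols are illustrative stand-ins rather than Eigen internals, and the unroll choices (7/4/2) are taken from the dispatch visible in gemmMMA_cols above.

// Illustrative sketch only -- not part of the patch and not Eigen's API.
// Each simulated accumulator covers accCols rows x accRows columns of the
// result; 'unroll' row tiles and 'accItr' column panels share one kernel
// call, and unroll * accItr never exceeds the 8 available accumulators.
#include <cstdio>

constexpr int accRows = 4;  // columns of the result per accumulator (RHS panel width)
constexpr int accCols = 4;  // rows of the result per accumulator (LHS tile height)

template <int unroll, int accItr>
void fake_kernel(int row, int col) {
  static_assert(unroll * accItr <= 8, "only 8 accumulators available");
  std::printf("kernel<%d,%d>: rows %d..%d, cols %d..%d\n",
              unroll, accItr,
              row, row + unroll * accCols - 1,
              col, col + accItr * accRows - 1);
}

// Row loop for a fixed number of simultaneous column panels, mirroring
// MICRO_MMA_ROWS: fewer row tiles are unrolled when more panels are active.
template <int accItr>
void do_cols(int rows, int col) {
  constexpr int unroll = (accItr == 1) ? 7 : (accItr == 2) ? 4 : 2;
  int row = 0;
  for (; row + unroll * accCols <= rows; row += unroll * accCols)
    fake_kernel<unroll, accItr>(row, col);
  // leftover row tiles would fall back to smaller unroll factors here
}

int main() {
  const int rows = 32, cols = 24;
  int col = 0;
  // Mirrors MICRO_MMA_COLS(4) / (2) / (1): consume groups of four column
  // panels first, then pairs, then single panels.
  for (; col + 4 * accRows <= cols; col += 4 * accRows) do_cols<4>(rows, col);
  for (; col + 2 * accRows <= cols; col += 2 * accRows) do_cols<2>(rows, col);
  for (; col + 1 * accRows <= cols; col += 1 * accRows) do_cols<1>(rows, col);
  return 0;
}

Built with any C++11 compiler, the sketch prints which block of the result each simulated kernel call would cover, in the same column-then-row order that gemmMMA and gemm_complexMMA walk the output with the new multiple-column path enabled.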