diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 2429c8126..61d961881 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -841,6 +841,345 @@ struct dhs_pack } }; +#ifdef __MMA__ +// General template for lhs packing, bfloat16 specialization. +template +struct dhs_pack +{ + EIGEN_STRONG_INLINE void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) + { + const Index vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + + for(; j + 2*vectorSize <= rows; j+=2*vectorSize) + { + const DataMapper lhs2 = lhs.getSubMapper(j, 0); + Index i = 0; + + if(PanelMode) ri += 2*vectorSize*offset; + + if(StorageOrder == ColMajor) + { + for(; i + 2 <= depth; i+=2) + { + PacketBlock block; + + block.packet[0] = lhs2.template loadPacket(0 * vectorSize, i + 0); + block.packet[1] = lhs2.template loadPacket(1 * vectorSize, i + 0); + block.packet[2] = lhs2.template loadPacket(0 * vectorSize, i + 1); + block.packet[3] = lhs2.template loadPacket(1 * vectorSize, i + 1); + + Packet8bf t0, t1; + t0 = vec_mergeh(block.packet[0].m_val, block.packet[2].m_val); + t1 = vec_mergel(block.packet[0].m_val, block.packet[2].m_val); + block.packet[2] = vec_mergeh(block.packet[1].m_val, block.packet[3].m_val); + block.packet[3] = vec_mergel(block.packet[1].m_val, block.packet[3].m_val); + block.packet[0] = t0; + block.packet[1] = t1; + + storeBlock(blockA + ri, block); + + ri += 2*2*vectorSize; + } + if (depth & 1) + { + PacketBlock block; + + block.packet[0] = lhs2.template loadPacket(0 * vectorSize, i + 0); + block.packet[1] = lhs2.template loadPacket(1 * vectorSize, i + 0); + + storeBlock(blockA + ri, block); + + ri += 2*vectorSize; + } + } else { + for(; i + vectorSize <= depth; i+=vectorSize) + { + PacketBlock block1, block2; + + bload(block1, lhs2, 0 * vectorSize, i); + bload(block2, lhs2, 1 * vectorSize, i); + + Packet2ul v1[8], v2[8]; + + v1[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); + v1[1] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); + v1[2] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val))); + v1[3] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val))); + v1[4] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val))); + v1[5] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val))); + v1[6] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val))); + v1[7] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val))); + v2[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(block2.packet[0].m_val), reinterpret_cast(block2.packet[1].m_val))); + v2[1] = reinterpret_cast(vec_mergel(reinterpret_cast(block2.packet[0].m_val), reinterpret_cast(block2.packet[1].m_val))); + v2[2] = reinterpret_cast(vec_mergeh(reinterpret_cast(block2.packet[2].m_val), reinterpret_cast(block2.packet[3].m_val))); + v2[3] = reinterpret_cast(vec_mergel(reinterpret_cast(block2.packet[2].m_val), reinterpret_cast(block2.packet[3].m_val))); + v2[4] = 
reinterpret_cast(vec_mergeh(reinterpret_cast(block2.packet[4].m_val), reinterpret_cast(block2.packet[5].m_val))); + v2[5] = reinterpret_cast(vec_mergel(reinterpret_cast(block2.packet[4].m_val), reinterpret_cast(block2.packet[5].m_val))); + v2[6] = reinterpret_cast(vec_mergeh(reinterpret_cast(block2.packet[6].m_val), reinterpret_cast(block2.packet[7].m_val))); + v2[7] = reinterpret_cast(vec_mergel(reinterpret_cast(block2.packet[6].m_val), reinterpret_cast(block2.packet[7].m_val))); + + block1.packet[0] = reinterpret_cast(vec_mergeh(v1[0],v1[2])); + block1.packet[2] = reinterpret_cast(vec_mergel(v1[0],v1[2])); + block1.packet[4] = reinterpret_cast(vec_mergeh(v1[1],v1[3])); + block1.packet[6] = reinterpret_cast(vec_mergel(v1[1],v1[3])); + block1.packet[1] = reinterpret_cast(vec_mergeh(v1[4],v1[6])); + block1.packet[3] = reinterpret_cast(vec_mergel(v1[4],v1[6])); + block1.packet[5] = reinterpret_cast(vec_mergeh(v1[5],v1[7])); + block1.packet[7] = reinterpret_cast(vec_mergel(v1[5],v1[7])); + block2.packet[0] = reinterpret_cast(vec_mergeh(v2[0],v2[2])); + block2.packet[2] = reinterpret_cast(vec_mergel(v2[0],v2[2])); + block2.packet[4] = reinterpret_cast(vec_mergeh(v2[1],v2[3])); + block2.packet[6] = reinterpret_cast(vec_mergel(v2[1],v2[3])); + block2.packet[1] = reinterpret_cast(vec_mergeh(v2[4],v2[6])); + block2.packet[3] = reinterpret_cast(vec_mergel(v2[4],v2[6])); + block2.packet[5] = reinterpret_cast(vec_mergeh(v2[5],v2[7])); + block2.packet[7] = reinterpret_cast(vec_mergel(v2[5],v2[7])); + + + for(Index M = 0; M < 8; M+=2) { + pstore(blockA + ri + (0 * vectorSize) + (2*vectorSize * M), block1.packet[M+0]); + pstore(blockA + ri + (1 * vectorSize) + (2*vectorSize * M), block1.packet[M+1]); + pstore(blockA + ri + (2 * vectorSize) + (2*vectorSize * M), block2.packet[M+0]); + pstore(blockA + ri + (3 * vectorSize) + (2*vectorSize * M), block2.packet[M+1]); + } + + ri += 2*vectorSize*vectorSize; + } + for(; i + 2 <= depth; i+=2) + { + for(Index M = 0; M < 2*vectorSize; M++) { + blockA[ri + (M * 2) + 0] = lhs2(M, i + 0); + blockA[ri + (M * 2) + 1] = lhs2(M, i + 1); + } + + ri += 2*2*vectorSize; + } + if (depth & 1) + { + for(Index M = 0; M < 2*vectorSize; M++) { + blockA[ri + M] = lhs2(M, i); + } + ri += 2*vectorSize; + } + } + + if(PanelMode) ri += 2*vectorSize*(stride - offset - depth); + } + for(; j + vectorSize <= rows; j+=vectorSize) + { + const DataMapper lhs2 = lhs.getSubMapper(j, 0); + Index i = 0; + + if(PanelMode) ri += vectorSize*offset; + + if(StorageOrder == ColMajor) + { + for(; i + 2 <= depth; i+=2) + { + PacketBlock block; + + block.packet[0] = lhs2.template loadPacket(0 * vectorSize, i + 0); + block.packet[1] = lhs2.template loadPacket(0 * vectorSize, i + 1); + + Packet8bf t0; + t0 = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val); + block.packet[1] = vec_mergel(block.packet[0].m_val, block.packet[1].m_val); + block.packet[0] = t0; + + storeBlock(blockA + ri, block); + + ri += 2*vectorSize; + } + if (depth & 1) + { + Packet8bf lhsV = lhs2.template loadPacket(0 * vectorSize, i + 0); + pstore(blockA + ri, lhsV); + + ri += vectorSize; + } + } else { + for(; i + vectorSize <= depth; i+=vectorSize) + { + PacketBlock block1; + + bload(block1, lhs2, 0 * vectorSize, i); + + Packet2ul v1[8]; + + v1[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); + v1[1] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); + v1[2] = 
reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val))); + v1[3] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val))); + v1[4] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val))); + v1[5] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val))); + v1[6] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val))); + v1[7] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val))); + + block1.packet[0] = reinterpret_cast(vec_mergeh(v1[0],v1[2])); + block1.packet[2] = reinterpret_cast(vec_mergel(v1[0],v1[2])); + block1.packet[4] = reinterpret_cast(vec_mergeh(v1[1],v1[3])); + block1.packet[6] = reinterpret_cast(vec_mergel(v1[1],v1[3])); + block1.packet[1] = reinterpret_cast(vec_mergeh(v1[4],v1[6])); + block1.packet[3] = reinterpret_cast(vec_mergel(v1[4],v1[6])); + block1.packet[5] = reinterpret_cast(vec_mergeh(v1[5],v1[7])); + block1.packet[7] = reinterpret_cast(vec_mergel(v1[5],v1[7])); + + for(Index M = 0; M < 8; M++) { + pstore(blockA + ri + (vectorSize * M), block1.packet[M]); + } + + ri += vectorSize*vectorSize; + } + for(; i + 2 <= depth; i+=2) + { + for(Index M = 0; M < vectorSize; M++) { + blockA[ri + (M * 2) + 0] = lhs2(M, i + 0); + blockA[ri + (M * 2) + 1] = lhs2(M, i + 1); + } + + ri += 2*vectorSize; + } + if (depth & 1) + { + for(Index M = 0; M < vectorSize; M++) { + blockA[ri + M] = lhs2(M, i); + } + + ri += vectorSize; + } + } + + if(PanelMode) ri += vectorSize*(stride - offset - depth); + } + + if(PanelMode) ri += offset; + + for(; j < rows; j++) + { + const DataMapper lhs2 = lhs.getSubMapper(j, 0); + for(Index i = 0; i < depth; i++) + { + blockA[ri] = lhs2(0, i); + ri += 1; + } + + if(PanelMode) ri += stride - depth; + } + } +}; + +// General template for rhs packing, bfloat16 specialization. 
+template +struct dhs_pack +{ + EIGEN_STRONG_INLINE void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) + { + const Index vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + + for(; j + 4 <= cols; j+=4) + { + const DataMapper rhs2 = rhs.getSubMapper(0, j); + Index i = 0; + + if(PanelMode) ri += 4*offset; + + for(; i + vectorSize <= depth; i+=vectorSize) + { + if(StorageOrder == ColMajor) + { + PacketBlock block; + + bload(block, rhs2, i, 0); + + Packet2ul t0, t1, t2, t3; + t0 = reinterpret_cast(vec_mergeh(reinterpret_cast(block.packet[0].m_val), reinterpret_cast(block.packet[1].m_val))); + t1 = reinterpret_cast(vec_mergeh(reinterpret_cast(block.packet[2].m_val), reinterpret_cast(block.packet[3].m_val))); + t2 = reinterpret_cast(vec_mergel(reinterpret_cast(block.packet[0].m_val), reinterpret_cast(block.packet[1].m_val))); + t3 = reinterpret_cast(vec_mergel(reinterpret_cast(block.packet[2].m_val), reinterpret_cast(block.packet[3].m_val))); + block.packet[0] = reinterpret_cast(vec_mergeh(t0, t1)); + block.packet[1] = reinterpret_cast(vec_mergel(t0, t1)); + block.packet[2] = reinterpret_cast(vec_mergeh(t2, t3)); + block.packet[3] = reinterpret_cast(vec_mergel(t2, t3)); + + storeBlock(blockB + ri, block); + } else { + PacketBlock block; + + for (int M = 0; M < 8; M++) { + block.packet[M] = rhs2.template loadPacketPartial(i + M, 0, 4); + } + + block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val); + block.packet[1] = vec_mergeh(block.packet[2].m_val, block.packet[3].m_val); + block.packet[2] = vec_mergeh(block.packet[4].m_val, block.packet[5].m_val); + block.packet[3] = vec_mergeh(block.packet[6].m_val, block.packet[7].m_val); + + const Index size = 16 / sizeof(bfloat16); + + for (int M = 0; M < 4; M++) { + pstore(blockB + ri + (M * size), block.packet[M]); + } + } + + ri += 4*vectorSize; + } + for (; i + 2 <= depth; i += 2) { + if(StorageOrder == ColMajor) + { + blockB[ri+0] = rhs2(i + 0, 0); + blockB[ri+1] = rhs2(i + 1, 0); + blockB[ri+2] = rhs2(i + 0, 1); + blockB[ri+3] = rhs2(i + 1, 1); + blockB[ri+4] = rhs2(i + 0, 2); + blockB[ri+5] = rhs2(i + 1, 2); + blockB[ri+6] = rhs2(i + 0, 3); + blockB[ri+7] = rhs2(i + 1, 3); + } else { + PacketBlock block; + + for (int M = 0; M < 2; M++) { + block.packet[M] = rhs2.template loadPacketPartial(i + M, 0, 4); + } + + block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val); + + pstore(blockB + ri, block.packet[0]); + } + + ri += 4*2; + } + if (depth & 1) + { + blockB[ri+0] = rhs2(i, 0); + blockB[ri+1] = rhs2(i, 1); + blockB[ri+2] = rhs2(i, 2); + blockB[ri+3] = rhs2(i, 3); + + ri += 4; + } + + if(PanelMode) ri += 4*(stride - offset - depth); + } + + if(PanelMode) ri += offset; + + for(; j < cols; j++) + { + const DataMapper rhs2 = rhs.getSubMapper(0, j); + for(Index i = 0; i < depth; i++) + { + blockB[ri] = rhs2(i, 0); + ri += 1; + } + + if(PanelMode) ri += stride - depth; + } + } +}; +#endif + // General template for lhs complex packing, float64 specialization. 
template struct dhs_cpack @@ -2322,6 +2661,64 @@ void gemm_pack_rhs +struct gemm_pack_rhs +{ + void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs + ::operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_rhs +{ + void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs + ::operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_lhs +{ + void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_lhs +{ + void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} +#endif + template struct gemm_pack_lhs { diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h index b3e063d46..01465d3e7 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h @@ -1,117 +1,64 @@ - #ifndef EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H - #define EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H +#ifndef EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H +#define EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H + +#if EIGEN_COMP_LLVM +#define BFLOAT16_UNROLL _Pragma("unroll 8") +#else +#define BFLOAT16_UNROLL _Pragma("GCC unroll(8)") +#endif namespace Eigen { namespace internal { -EIGEN_STRONG_INLINE void pgerMMAbfloat16(__vector_quad* acc, const Packet8bf& a, const Packet8bf& b, int maskX, int maskY) -{ - switch(maskX){ - case 15: - switch(maskY){ - case 0b1111: - __builtin_mma_xvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val)); - break; - case 0b0011: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b1111, 0b11, 0b11); - break; - case 0b0001: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b1111, 0b1, 0b11); - break; - case 0b0111: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b1111, 0b111, 0b11); - break; - } - break; - case 3: - switch(maskY){ - case 0b1111: - __builtin_mma_xvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val)); - break; - case 0b0011: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b11, 0b11, 0b11); - break; - case 0b0001: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b11, 0b1, 0b11); - break; - case 0b0111: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b11, 0b111, 0b11); - break; - } - break; - case 1: - switch(maskY){ 
- case 0b1111: - __builtin_mma_xvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val)); - break; - case 0b0011: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b1, 0b11, 0b11); - break; - case 0b0001: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b1, 0b1, 0b11); - break; - case 0b0111: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b1, 0b111, 0b11); - break; - } - break; - case 0b0111: - switch(maskY){ - case 0b1111: - __builtin_mma_xvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val)); - break; - case 0b0011: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b111, 0b11, 0b11); - break; - case 0b0001: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b111, 0b1, 0b11); - break; - case 0b0111: - __builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast(a.m_val), reinterpret_cast(b.m_val), 0b111, 0b111, 0b11); - break; - } - break; - } -} - -EIGEN_STRONG_INLINE void scaleAndStore(float* result, float* acc, Packet4f pAlpha) +EIGEN_ALWAYS_INLINE void scaleAndStore(float* result, Packet4f& acc, const Packet4f& pAlpha) { Packet4f result_block = ploadu(result); - Packet4f packet_pmadd = pmadd(pload(acc), pAlpha, result_block); - pstoreu(result, packet_pmadd); + result_block = pmadd(acc, pAlpha, result_block); + pstoreu(result, result_block); } -template -EIGEN_STRONG_INLINE Packet8bf loadLhsBfloat16(const bfloat16* indexA) +template +EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16(const bfloat16* indexA) { Packet8bf lhs1 = ploadu(indexA); - Packet8bf lhs2; - const int packet_size = 8; //We fit 8 bfloat16 on a 128 register if(zero){ - lhs2 = pset1(Eigen::bfloat16(0)); + Packet8bf lhs2 = pset1(Eigen::bfloat16(0)); + return vec_mergeh(lhs1.m_val, lhs2.m_val); + } else { + return lhs1; } - else lhs2 = ploadu(indexA + num_packets*packet_size); - return vec_mergeh(lhs1.m_val, lhs2.m_val); } template -EIGEN_STRONG_INLINE Packet8bf loadLhsBfloat16ExtraRows(const bfloat16* indexA, Index strideA, Index row, int extra_rows) +EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16Extra(const bfloat16* indexA, Index strideA, Index extra_rows) { - EIGEN_ALIGN16 bfloat16 lhs_array[8] = {Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0)}; - int count = 0; - const bfloat16* idxA = indexA + row*strideA; - for(int row_count = 0; row_count < extra_rows; row_count++){ - lhs_array[count++] = *idxA; - if(!zero) lhs_array[count] = *(idxA+1); - count++; - idxA += strideA; + Index row_count = 0; + if (zero) { + EIGEN_ALIGN16 bfloat16 lhs_array[8] = { Eigen::bfloat16(0) }; + do{ + lhs_array[row_count] = *indexA; + indexA += strideA; + } while ((row_count += 2) < extra_rows*2); + return pload_partial(lhs_array, extra_rows*2); + } else { + EIGEN_ALIGN16 int lhs_array[4]; + do{ + lhs_array[row_count] = *reinterpret_cast(indexA); + indexA += strideA; + } while ((row_count += 1) < extra_rows); + return reinterpret_cast(pload_partial(lhs_array, extra_rows)); } - return pload(lhs_array); } template -EIGEN_STRONG_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strideB, int i, int k) +EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16ExtraRows(const bfloat16* indexA, Index strideA, Index row, Index extra_rows) +{ + return loadBfloat16Extra(indexA + row*strideA, strideA, extra_rows); +} + 
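The zero paths in the load helpers above exist because the accumulate instruction used in the k-loop, __builtin_mma_xvbf16ger2pp, is a rank-2 update: each 128-bit operand is read as four bfloat16 pairs and every accumulator entry receives the dot product of one pair from each operand, so two depth steps are consumed per call. When depth is odd, the helpers pair the last real value with a zero so that final step degenerates to a plain rank-1 update. A minimal scalar sketch of that semantics, with float operands and a hypothetical function name used purely for illustration:

// Scalar model of a rank-2 bf16 accumulate in the style of xvbf16ger2pp
// (illustrative only; the element ordering inside the real VSRs is not modeled).
// 'a' and 'b' each hold four pairs; acc[i][j] += a_pair[i] . b_pair[j].
void bf16_ger2_model(float acc[4][4], const float a[4][2], const float b[4][2])
{
  for (int i = 0; i < 4; i++)
    for (int j = 0; j < 4; j++)
      acc[i][j] += a[i][0] * b[j][0] + a[i][1] * b[j][1];
}
// With an odd depth, the second element of every pair is zero for the last
// k step, so that call only contributes a[i][0] * b[j][0].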
+template +EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strideB, Index i, Index k) { const bfloat16* indexB = baseB + strideB*4*i + (k*4); Packet8bf rhs1 = ploadu(indexB); @@ -119,28 +66,16 @@ EIGEN_STRONG_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strid Packet8bf rhs2 = pset1(Eigen::bfloat16(0)); return vec_mergeh(rhs1.m_val, rhs2.m_val); } - //r = vec_perm (a, b, c) - //Let v be the concatenation of a and b. - //Each byte of r selected by using the least-significant 5 bits of the corresponding byte of c as an index into v - //We need this elements from rhs: 0, 4, 1, 5, 2, 6, 3, 7 - Packet16uc c = {0x0u, 0x1u, 0x8u, 0x9u, 0x2u, 0x3u, 0xAu, 0xB, 0x4, 0x5, 0xCu, 0xDu, 0x6u, 0x7u, 0xEu, 0xFu}; - return vec_perm(rhs1.m_val, rhs1.m_val, c); + return rhs1; } template -EIGEN_STRONG_INLINE Packet8bf loadRhsBfloat16ExtraCols(const bfloat16* blockB, Index strideB, Index offsetB, Index col, int i, int k, int extra_cols) +EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16ExtraCols(const bfloat16* blockB, Index strideB, Index offsetB, Index col, Index i, Index k, Index extra_cols) { - EIGEN_ALIGN16 bfloat16 rhs_vector[8] = {Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0)}; - const bfloat16* indexB = blockB + ((col+4*i)*strideB)+k+offsetB; - for(int c = 0; c < extra_cols; c++){ - rhs_vector[2*c] = *indexB; - if(!zero) rhs_vector[2*c+1] = *(indexB+1); - indexB += strideB; - } - return pload(rhs_vector); + return loadBfloat16Extra(blockB + ((col+4*i)*strideB)+k+offsetB, strideB, extra_cols); } -template +template EIGEN_STRONG_INLINE void KLoop ( const bfloat16* indexA, @@ -152,107 +87,95 @@ EIGEN_STRONG_INLINE void KLoop Index k, Index row, Index col, - int extra_rows, - int extra_cols, - int mask_rows = 0xF, - int mask_cols = 0xF + Index extra_rows, + Index extra_cols ) { Packet8bf lhs; Packet8bf rhs[num_acc]; if(lhs_extra_rows) lhs = loadLhsBfloat16ExtraRows(indexA+k, strideA, row, extra_rows); - else lhs = loadLhsBfloat16(indexA + k*num_packets*8); //a packet of bfloat16 has 8 elements - for(int i = 0; i < num_acc; i++){ + else lhs = loadLhsBfloat16(indexA + k*num_packets); //a packet of bfloat16 has 8 elements + BFLOAT16_UNROLL + for(Index i = 0; i < num_acc; i++){ if(!rhs_extra_cols) rhs[i] = loadRhsBfloat16(indexB, strideB, i, k); else{ rhs[i] = loadRhsBfloat16ExtraCols(indexB, strideB, offsetB, col, i, k, extra_cols); } - pgerMMAbfloat16(&(quad_acc[i]), rhs[i], lhs, mask_cols, mask_rows); + __builtin_mma_xvbf16ger2pp(&(quad_acc[i]), reinterpret_cast(rhs[i].m_val), reinterpret_cast(lhs.m_val)); } } -template -void colLoopBody(Index* p_col, Index row, Index depth, Index cols, Index rows, int offset_row, int block_index, Packet4f pAlpha, const bfloat16* indexA, Index strideA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, int extra_cols = 0, int extra_rows = 0, int mask_cols = 0xF, int mask_rows = 0xF) +template +void colLoopBody(Index& col, Index row, Index depth, Index cols, Index rows, Index offset_row, Index block_index, const Packet4f& pAlpha, const bfloat16* indexA, Index strideA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_cols = 0, Index extra_rows = 0) { - int col = *p_col; - int count; - int max, step, bound; - const bfloat16* indexB; + const Index step = rhsExtraCols ? 1 : (num_acc * 4); //each accumulator has 4 elements + const bfloat16* indexB = rhsExtraCols ? 
blockB : (blockB + 4*offsetB + strideB*col); - if(num_acc == 1) bound = 0; - else bound = 1; - - if(rhsExtraCols){ - count = 0; - max = 1; - step = 1; - indexB = blockB; - } - else{ - count = col; - step = num_acc * 4; //each accumulator has 4 elements - max = cols/step; - indexB = blockB + 4*offsetB + strideB*col; - } - - while(count/step + bound < max){ + while(col + step <= cols){ Index k = 0; - EIGEN_ALIGN32 float acc[num_acc][4][4]; + Packet4f acc[num_acc][4]; __vector_quad quad_acc[num_acc]; - for(int i = 0; i < num_acc; i++) + BFLOAT16_UNROLL + for(Index i = 0; i < num_acc; i++) __builtin_mma_xxsetaccz(&(quad_acc[i])); - if(depth%2 != 0){ - KLoop(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols, mask_rows, mask_cols); - k = 1; + for(; k + 2 <= depth; k += 2){ + KLoop(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols); } - for(; k/2 < depth/2; k += 2){ - KLoop(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols, mask_rows, mask_cols); + if(depth&1){ + KLoop(indexA-offset_row, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols); } - for(int i = 0; i < num_acc; i++){ + + BFLOAT16_UNROLL + for(Index i = 0; i < num_acc; i++) __builtin_mma_disassemble_acc((void*)acc[i], &(quad_acc[i])); + + for(Index i = 0; i < num_acc; i++){ if(lhsExtraRows){ - for(int x = 0; x < extra_cols; x++){ - for(int y = 0; y < extra_rows; y++){ - result[((col+i*4)+x)*rows + row + y] += acc[i][x][y]*(pAlpha[0]); - } + float *r = result + (col+i*4)*rows + row; + for(Index x = 0; x < extra_cols; x++, r += rows){ + Packet4f result_block = ploadu_partial(r, extra_rows); + result_block = pmadd(acc[i][x], pAlpha, result_block); + pstoreu_partial(r, result_block, extra_rows); } } else{ if(rhsExtraCols){ - for(int x = 0; x < cols-col; x++){ - scaleAndStore(result + ((col+i*4)+x)*rows + row + offset_row,acc[i][x], pAlpha); + float *r = result + (col+i*4)*rows + row + offset_row; + for(Index x = 0; x < cols-col; x++, r += rows){ + scaleAndStore(r,acc[i][x], pAlpha); } } else{ - for(int x = 0; x < 4; x++){ - scaleAndStore(result + ((col+i*4)+x)*rows + (block_index*16) + offset_row,acc[i][x], pAlpha); + float *r = result + (col+i*4)*rows + (block_index*16) + offset_row; + for(Index x = 0; x < 4; x++, r += rows){ + scaleAndStore(r,acc[i][x], pAlpha); } } } } - count += step; - if(!rhsExtraCols) { - indexB += strideB*step; - col += step; - } + if(rhsExtraCols) return; + indexB += strideB*step; + col += step; } - *p_col = col; } template void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { - if(rows == 0 || cols == 0 || depth == 0) return; - const Packet4f pAlpha = pset1(Eigen::bfloat16_impl::bfloat16_to_float(alpha)); + float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha); + if (falpha == float(0)) return; + const Packet4f pAlpha = pset1(falpha); ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0); - for(int j = 0; j < cols; j++){ - for(int i = 0; i < rows; i++){ - result[j*rows + i] = res(i,j); + typedef typename DataMapper::LinearMapper LinearMapper; + for(Index j = 0; j < cols; j++){ + const LinearMapper res2 = res.getLinearMapper(0, j); + for(Index i = 0; i < rows; i++){ + result[j*rows + i] = res2(i); } } @@ -268,26 +191,27 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, 
const bfloat
 
   //Blocks of 8 columns with <8 elements as row major. This happens when there's less than 8 remaining rows
   //Loop for LHS standard block (8x16)
-  int standard_block_size = 16;
-  const int standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks
-  int bigSuffix = (2*8) * (strideA-offsetA-depth);
+  const Index standard_block_size = 16;
+  const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks
+  Index bigSuffix = (2*8) * (strideA-offsetA-depth);
   const bfloat16* indexA = blockA;
-  int block_index;
+  const Index offset_factor = 2;
+  Index block_index;
   for(block_index = 0; block_index < standard_blocks_quantity; block_index++){
     indexA += 2*8*offsetA;
-    for(int offset_row = 0; offset_row < standard_block_size; offset_row += 4){ //This block size has 16 rows maximum
+    for(Index offset_row = 0; offset_row < standard_block_size; offset_row += 4){ //This block size has 16 rows maximum
       col = 0;
-      colLoopBody<5, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<4, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<3, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<2, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<1, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<7, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<6, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<5, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<4, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<3, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<2, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<1, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
       if(cols > col){
-        int extra_cols= cols-col;
-        int shift = (4-extra_cols>= 0) ? 4-extra_cols: 0;
-        int mask_cols= 0xF >> shift;
+        Index extra_cols= cols-col;
         //Remember: It doesnt make sense use multiple acc to extra_cols as we are unrolling col loop
-        colLoopBody<1, 16, 2, true>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result, extra_cols, 4, mask_cols, 0xF);
+        colLoopBody<1, 16, true>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result, extra_cols, 4);
       }
     }
     row += 16;
@@ -296,75 +220,66 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
   //LHS (8x8) block
   if(rows - standard_blocks_quantity*16 >= 8){
     indexA += 1*8*offsetA + 2*8*offsetA;
-    for(int offset_row = 0; offset_row < 8; offset_row += 4){
+    for(Index offset_row = 0; offset_row < 8; offset_row += 4){
       col = 0;
-      colLoopBody<5, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<4, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<3, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<2, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<1, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<7, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<6, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<5, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<4, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<3, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<2, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
+      colLoopBody<1, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
     }
     if(cols > col){
-      int extra_cols= cols-col;
-      int shift = (4-extra_cols>= 0) ? 4-extra_cols: 0;
-      int mask_cols= 0xF >> shift;
+      Index extra_cols= cols-col;
 
-      for(int offset_row = 0; offset_row < 8; offset_row += 4){
-        colLoopBody<1, 8, 1, true>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result, extra_cols, 4, mask_cols, 0xF);
+      for(Index offset_row = 0; offset_row < 8; offset_row += 4){
+        colLoopBody<1, 8, true>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result, extra_cols, 4);
       }
     } //end extra cols
     row += 8;
   }
   //extra rows
   while(row < rows){
-    int extra_rows = rows-row;
-    int shift = (4-extra_rows >= 0) ? 4-extra_rows : 0;
-    int mask_rows = 0xF >> shift;
-    int extra_rows_or_four = (extra_rows <= 4) ? extra_rows : 4;
+    Index extra_rows = rows-row;
+    Index extra_rows_or_four = (extra_rows <= 4) ? extra_rows : 4;
     //This index is the beginning of remaining block.
     //This last block for LHS is organized as RowMajor
     col = 0;
-    colLoopBody<5, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows);
-    colLoopBody<4, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows);
-    colLoopBody<3, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows);
-    colLoopBody<2, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows);
-    colLoopBody<1, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows);
+    colLoopBody<7, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
+    colLoopBody<6, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
+    colLoopBody<5, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
+    colLoopBody<4, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
+    colLoopBody<3, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
+    colLoopBody<2, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
+    colLoopBody<1, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
     if(cols > col){
-      int extra_cols= cols-col;
-      int shift = (4-extra_cols>= 0) ? 4-extra_cols: 0;
-      int mask_cols= 0xF >> shift;
-      int extra_cols_or_four = (extra_cols <= 4) ? extra_cols : 4;
+      Index extra_cols= cols-col;
 
-      colLoopBody<1, 8, 1, true, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, extra_cols_or_four, extra_rows_or_four, mask_cols, mask_rows);
+      colLoopBody<1, 8, true, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, extra_cols, extra_rows_or_four);
     }
     row += extra_rows_or_four;
   }
 
   //Convert back to bfloat16
-  for(col = 0; col/4 < cols/4; col += 4){
-    int row;
-    for(row = 0; row/8 < rows/8; row += 8){
+  for(col = 0; col + 4 <= cols; col += 4){
+    const DataMapper res2 = res.getSubMapper(0, col);
+    for(row = 0; row + 8 <= rows; row += 8){
       //get and save block
       PacketBlock<Packet8bf,4> block;
-      for(int j = 0; j < 4; j++){
-        Packet4f temp_even, temp_odd;
-        EIGEN_ALIGN32 float even[4], odd[4];
-        for(int i = 0; i < 4; i++){
-          even[i] = result[(col + j)*rows + row + i*2];
-          odd[i] = result[(col + j)*rows + row + i*2+1];
-        }
-        temp_even = pload<Packet4f>(even);
-        temp_odd = pload<Packet4f>(odd);
-        block.packet[j] = F32ToBf16(temp_even, temp_odd);
+      for(Index j = 0; j < 4; j++){
+        Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(pload<Packet4f>(result + (col + j)*rows + row)));
+        Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(pload<Packet4f>(result + (col + j)*rows + row + 4)));
+        block.packet[j].m_val = vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
       }
-      res.template storePacketBlock<Packet8bf,4>(row, col, block);
+      res2.template storePacketBlock<Packet8bf,4>(row, 0, block);
     }
     //extra rows
     while(row < rows){
-      for(int col_off = 0; col_off < 4; col_off++){
-        res(row, col+col_off) = Eigen::bfloat16(result[(col+col_off)*rows+row]);
+      for(Index col_off = 0; col_off < 4; col_off++){
+        res2(row, col_off) = Eigen::bfloat16(result[(col+col_off)*rows+row]);
       }
       row++;
     }
@@ -372,8 +287,9 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
   }
   //extra cols
   while(col < cols){
-    for(int r = 0; r < rows; r++){
-      res(r, col) = Eigen::bfloat16(result[col*rows + r]);
+    const LinearMapper res2 = res.getLinearMapper(0, col);
+    for(Index r= 0; r< rows; r++){
+      res2(r) = Eigen::bfloat16(result[col*rows + r]);
     }
     col++;
   }
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index 3f7638e17..a2486af41 100644
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -380,11 +380,6 @@ public:
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const {
     spbh.store(sup, i,j,block);
     sup->template storePacket<SubPacket>(i, j+idx, block.packet[idx]);
-    //for(int l = 0; l < unpacket_traits<SubPacket>::size; l++)
-    //{
-    //  Scalar_ *v = &sup->operator()(i+l, j+idx);
-    //  *v = *reinterpret_cast<Scalar_*>(&block.packet[idx][l]);
-    //}
   }
 };
 
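The epilogue above accumulates in float and converts back to bfloat16 once per output block, using __builtin_vsx_xvcvspbf16 plus vec_pack in the vector path. Per lane that conversion is the usual truncation to the upper 16 bits of the IEEE-754 single-precision encoding with round-to-nearest-even; a standalone scalar sketch with a hypothetical helper name (NaN payloads are not handled here):

#include <cstdint>
#include <cstring>

// Scalar sketch of one lane of the float -> bfloat16 conversion performed by
// the vectorized epilogue: round to nearest, ties to even, then keep the top
// 16 bits of the single-precision bit pattern.
static std::uint16_t float_to_bf16_rne(float f)
{
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u); // rounding bias carries into the kept bits
  return static_cast<std::uint16_t>(bits >> 16);
}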