From 6fc9de7d93ec270313d2ffba4ac64ea58da2e1ca Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Wed, 25 Jan 2023 18:22:20 +0000 Subject: [PATCH] Fix slowdown in bfloat16 MMA when rows is not a multiple of 8 or columns is not a multiple of 4. --- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 113 +++++++++++++++--- .../arch/AltiVec/MatrixProductMMAbfloat16.h | 95 ++++++++------- 2 files changed, 146 insertions(+), 62 deletions(-) diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 61d961881..adb2eac7a 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -1007,6 +1007,7 @@ struct dhs_pack Packet2ul v1[8]; + // This is transposing and interleaving data v1[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); v1[1] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); v1[2] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val))); @@ -1052,19 +1053,82 @@ struct dhs_pack if(PanelMode) ri += vectorSize*(stride - offset - depth); } - - if(PanelMode) ri += offset; - - for(; j < rows; j++) + if(j + 4 <= rows) { const DataMapper lhs2 = lhs.getSubMapper(j, 0); - for(Index i = 0; i < depth; i++) + Index i = 0; + + if(PanelMode) ri += 4*offset; + + for(; i + 2 <= depth; i+=2) { - blockA[ri] = lhs2(0, i); - ri += 1; + if(StorageOrder == ColMajor) + { + PacketBlock block; + + block.packet[0] = lhs2.template loadPacketPartial(0, i + 0, 4); + block.packet[1] = lhs2.template loadPacketPartial(0, i + 1, 4); + + block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val); + + pstore(blockA + ri, block.packet[0]); + } else { + blockA[ri+0] = lhs2(0, i + 0); + blockA[ri+1] = lhs2(0, i + 1); + blockA[ri+2] = lhs2(1, i + 0); + blockA[ri+3] = lhs2(1, i + 1); + blockA[ri+4] = lhs2(2, i + 0); + blockA[ri+5] = lhs2(2, i + 1); + blockA[ri+6] = lhs2(3, i + 0); + blockA[ri+7] = lhs2(3, i + 1); + } + + ri += 2*4; + } + if (depth & 1) + { + if(StorageOrder == ColMajor) + { + Packet8bf lhsV = lhs2.template loadPacketPartial(0, i + 0, 4); + + pstore_partial(blockA + ri, lhsV, 4); + } else { + blockA[ri+0] = lhs2(0, i); + blockA[ri+1] = lhs2(1, i); + blockA[ri+2] = lhs2(2, i); + blockA[ri+3] = lhs2(3, i); + } + + ri += 4; } - if(PanelMode) ri += stride - depth; + if(PanelMode) ri += 4*(stride - offset - depth); + j += 4; + } + + if (j < rows) + { + if(PanelMode) ri += offset*(rows - j); + + Index i = 0; + for(; i + 2 <= depth; i+=2) + { + Index k = j; + for(; k < rows; k++) + { + blockA[ri+0] = lhs(k, i + 0); + blockA[ri+1] = lhs(k, i + 1); + ri += 2; + } + } + if (depth & 1) + { + for(; j < rows; j++) + { + blockA[ri] = lhs(j, i); + ri += 1; + } + } } } }; @@ -1163,18 +1227,29 @@ struct dhs_pack if(PanelMode) ri += 4*(stride - offset - depth); } - if(PanelMode) ri += offset; - - for(; j < cols; j++) + if (j < cols) { - const DataMapper rhs2 = rhs.getSubMapper(0, j); - for(Index i = 0; i < depth; i++) - { - blockB[ri] = rhs2(i, 0); - ri += 1; - } + if(PanelMode) ri += offset*(cols - j); - if(PanelMode) ri += stride - depth; + Index i = 0; + for(; i + 2 <= depth; i+=2) + { + Index k = j; + for(; k < cols; k++) + { + blockB[ri+0] = rhs(i + 0, k); + blockB[ri+1] = rhs(i + 1, k); + ri += 2; + } + } + if (depth & 1) + { + for(; j < cols; j++) + { + blockB[ri] = rhs(i, j); + ri += 1; + } + } } } }; @@ -2662,6 +2737,7 @@ void gemm_pack_rhs struct gemm_pack_rhs { @@ -2689,6 +2765,7 @@ void gemm_pack_rhs pack; pack(blockB, rhs, depth, cols, stride, offset); } +#endif template struct gemm_pack_lhs diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h index 01465d3e7..c6d421600 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h @@ -18,8 +18,8 @@ EIGEN_ALWAYS_INLINE void scaleAndStore(float* result, Packet4f& acc, const Packe pstoreu(result, result_block); } -template -EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16(const bfloat16* indexA) +template +EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA) { Packet8bf lhs1 = ploadu(indexA); if(zero){ @@ -31,48 +31,33 @@ EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16(const bfloat16* indexA) } template -EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16Extra(const bfloat16* indexA, Index strideA, Index extra_rows) +EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16Extra(const bfloat16* indexA, Index extra_rows) { - Index row_count = 0; if (zero) { - EIGEN_ALIGN16 bfloat16 lhs_array[8] = { Eigen::bfloat16(0) }; - do{ - lhs_array[row_count] = *indexA; - indexA += strideA; - } while ((row_count += 2) < extra_rows*2); - return pload_partial(lhs_array, extra_rows*2); + Packet8bf lhs1 = ploadu_partial(indexA, extra_rows); + Packet8bf lhs2 = pset1(Eigen::bfloat16(0)); + return vec_mergeh(lhs1.m_val, lhs2.m_val); } else { - EIGEN_ALIGN16 int lhs_array[4]; - do{ - lhs_array[row_count] = *reinterpret_cast(indexA); - indexA += strideA; - } while ((row_count += 1) < extra_rows); - return reinterpret_cast(pload_partial(lhs_array, extra_rows)); + return reinterpret_cast(ploadu_partial(reinterpret_cast(indexA), extra_rows)); } } template EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16ExtraRows(const bfloat16* indexA, Index strideA, Index row, Index extra_rows) { - return loadBfloat16Extra(indexA + row*strideA, strideA, extra_rows); + return loadBfloat16Extra(indexA + row*strideA, extra_rows); } template EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strideB, Index i, Index k) { - const bfloat16* indexB = baseB + strideB*4*i + (k*4); - Packet8bf rhs1 = ploadu(indexB); - if(zero){ - Packet8bf rhs2 = pset1(Eigen::bfloat16(0)); - return vec_mergeh(rhs1.m_val, rhs2.m_val); - } - return rhs1; + return loadBfloat16(baseB + strideB*4*i + (k*4)); } template EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16ExtraCols(const bfloat16* blockB, Index strideB, Index offsetB, Index col, Index i, Index k, Index extra_cols) { - return loadBfloat16Extra(blockB + ((col+4*i)*strideB)+k+offsetB, strideB, extra_cols); + return loadBfloat16Extra(blockB + ((col+4*i)*strideB)+k*extra_cols+offsetB, extra_cols); } template @@ -93,8 +78,8 @@ EIGEN_STRONG_INLINE void KLoop { Packet8bf lhs; Packet8bf rhs[num_acc]; - if(lhs_extra_rows) lhs = loadLhsBfloat16ExtraRows(indexA+k, strideA, row, extra_rows); - else lhs = loadLhsBfloat16(indexA + k*num_packets); //a packet of bfloat16 has 8 elements + if(lhs_extra_rows) lhs = loadLhsBfloat16ExtraRows(indexA+k*extra_rows, strideA, row, extra_rows); + else lhs = loadBfloat16(indexA + k*num_packets); //a packet of bfloat16 has 8 elements BFLOAT16_UNROLL for(Index i = 0; i < num_acc; i++){ if(!rhs_extra_cols) @@ -125,7 +110,7 @@ void colLoopBody(Index& col, Index row, Index depth, Index cols, Index rows, Ind KLoop(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols); } if(depth&1){ - KLoop(indexA-offset_row, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols); + KLoop(indexA-(offset_row&(num_packets-1)), indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols); } BFLOAT16_UNROLL @@ -174,8 +159,10 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat typedef typename DataMapper::LinearMapper LinearMapper; for(Index j = 0; j < cols; j++){ const LinearMapper res2 = res.getLinearMapper(0, j); + float *result2 = result + j*rows; + BFLOAT16_UNROLL for(Index i = 0; i < rows; i++){ - result[j*rows + i] = res2(i); + result2[i] = res2(i); } } @@ -185,15 +172,16 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat if( strideA == -1 ) strideA = depth; if( strideB == -1 ) strideB = depth; //Packing is done in blocks. - //There's 3 possible sizes of blocks - //Blocks of 8 columns with 16 elements (8x16) as col major - //Blocks of 8 columns with 8 elements (8x8) as col major. This happens when there's 16 > rows > 8 - //Blocks of 8 columns with <8 elements as row major. This happens when there's less than 8 remaining rows + //There's 4 possible sizes of blocks + //Blocks of 8 columns with 16 elements (8x16) + //Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8 + //Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4 + //Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows //Loop for LHS standard block (8x16) const Index standard_block_size = 16; const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks - Index bigSuffix = (2*8) * (strideA-offsetA-depth); + Index bigSuffix = (2*8) * (strideA-offsetA); const bfloat16* indexA = blockA; const Index offset_factor = 2; Index block_index; @@ -215,11 +203,11 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat } } row += 16; - indexA += bigSuffix + 2*8*depth; + indexA += bigSuffix; } //LHS (8x8) block - if(rows - standard_blocks_quantity*16 >= 8){ - indexA += 1*8*offsetA + 2*8*offsetA; + if(rows & 8){ + indexA += 1*8*offsetA; for(Index offset_row = 0; offset_row < 8; offset_row += 4){ col = 0; colLoopBody<7, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); @@ -238,14 +226,33 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat } } //end extra cols row += 8; + indexA += (bigSuffix >> 1); + } + //LHS (8x4) block + if(rows & 4){ + Index offset_row = (rows & 8); + indexA += 1*4*offsetA; + col = 0; + colLoopBody<7, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); + colLoopBody<6, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); + colLoopBody<5, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); + colLoopBody<4, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); + colLoopBody<3, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); + colLoopBody<2, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); + colLoopBody<1, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); + if(cols > col){ + Index extra_cols= cols-col; + + colLoopBody<1, 4, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result, extra_cols, 4); + } + row += 4; + indexA += (bigSuffix >> 2); } //extra rows - while(row < rows){ - Index extra_rows = rows-row; - Index extra_rows_or_four = (extra_rows <= 4) ? extra_rows : 4; + if(row < rows){ + Index extra_rows_or_four = rows-row; - //This index is the beginning of remaining block. - //This last block for LHS is organized as RowMajor + //This index is the beginning of remaining block. col = 0; colLoopBody<7, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four); colLoopBody<6, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four); @@ -269,8 +276,8 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat //get and save block PacketBlock block; for(Index j = 0; j < 4; j++){ - Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast(pload(result + (col + j)*rows + row))); - Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast(pload(result + (col + j)*rows + row + 4))); + Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(result + (col + j)*rows + row))); + Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(result + (col + j)*rows + row + 4))); block.packet[j].m_val = vec_pack(reinterpret_cast(fp16_0), reinterpret_cast(fp16_1)); }