From fba12e02b3e3751c36a5d5b4a3e5aabb41e6b701 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 10 Feb 2023 17:32:06 +0000 Subject: [PATCH] Fold extra column calculations into an extra MMA accumulator and other bfloat16 MMA GEMM improvements --- .../arch/AltiVec/MatrixProductMMAbfloat16.h | 331 +++++++++--------- 1 file changed, 164 insertions(+), 167 deletions(-) diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h index c6d421600..91c1dd764 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h @@ -11,13 +11,6 @@ namespace Eigen { namespace internal { -EIGEN_ALWAYS_INLINE void scaleAndStore(float* result, Packet4f& acc, const Packet4f& pAlpha) -{ - Packet4f result_block = ploadu(result); - result_block = pmadd(acc, pAlpha, result_block); - pstoreu(result, result_block); -} - template EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA) { @@ -31,122 +24,161 @@ EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA) } template -EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16Extra(const bfloat16* indexA, Index extra_rows) +EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index strideB, Index i) { - if (zero) { - Packet8bf lhs1 = ploadu_partial(indexA, extra_rows); - Packet8bf lhs2 = pset1(Eigen::bfloat16(0)); - return vec_mergeh(lhs1.m_val, lhs2.m_val); - } else { - return reinterpret_cast(ploadu_partial(reinterpret_cast(indexA), extra_rows)); - } + return loadBfloat16(blockB + strideB*i); } -template -EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16ExtraRows(const bfloat16* indexA, Index strideA, Index row, Index extra_rows) -{ - return loadBfloat16Extra(indexA + row*strideA, extra_rows); -} - -template -EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strideB, Index i, Index k) -{ - return loadBfloat16(baseB + strideB*4*i + (k*4)); -} - -template -EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16ExtraCols(const bfloat16* blockB, Index strideB, Index offsetB, Index col, Index i, Index k, Index extra_cols) -{ - return loadBfloat16Extra(blockB + ((col+4*i)*strideB)+k*extra_cols+offsetB, extra_cols); -} - -template -EIGEN_STRONG_INLINE void KLoop +template +EIGEN_ALWAYS_INLINE void KLoop ( const bfloat16* indexA, const bfloat16* indexB, __vector_quad (&quad_acc)[num_acc], - Index strideA, Index strideB, - Index offsetB, Index k, - Index row, - Index col, - Index extra_rows, - Index extra_cols + Index offsetB, + Index extra_cols, + Index extra_rows ) { - Packet8bf lhs; + Packet8bf lhs = loadBfloat16(indexA + k*(lhsExtraRows ? extra_rows : num_packets)); //a packet of bfloat16 has 8 elements Packet8bf rhs[num_acc]; - if(lhs_extra_rows) lhs = loadLhsBfloat16ExtraRows(indexA+k*extra_rows, strideA, row, extra_rows); - else lhs = loadBfloat16(indexA + k*num_packets); //a packet of bfloat16 has 8 elements + + for(Index i = 0; i < (num_acc - (rhsExtraCols ? 
1 : 0)); i++){ + rhs[i] = loadRhsBfloat16(indexB + k*4, strideB, i); + } + if(rhsExtraCols) { + rhs[num_acc - 1] = loadRhsBfloat16(indexB + k*extra_cols - offsetB, strideB, num_acc - 1); + } + BFLOAT16_UNROLL - for(Index i = 0; i < num_acc; i++){ - if(!rhs_extra_cols) - rhs[i] = loadRhsBfloat16(indexB, strideB, i, k); - else{ - rhs[i] = loadRhsBfloat16ExtraCols(indexB, strideB, offsetB, col, i, k, extra_cols); - } + for (Index i = 0; i < num_acc; i++) { __builtin_mma_xvbf16ger2pp(&(quad_acc[i]), reinterpret_cast(rhs[i].m_val), reinterpret_cast(lhs.m_val)); } } -template -void colLoopBody(Index& col, Index row, Index depth, Index cols, Index rows, Index offset_row, Index block_index, const Packet4f& pAlpha, const bfloat16* indexA, Index strideA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_cols = 0, Index extra_rows = 0) +template +EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows) { - const Index step = rhsExtraCols ? 1 : (num_acc * 4); //each accumulator has 4 elements - const bfloat16* indexB = rhsExtraCols ? blockB : (blockB + 4*offsetB + strideB*col); - - while(col + step <= cols){ - Index k = 0; - Packet4f acc[num_acc][4]; - __vector_quad quad_acc[num_acc]; - - BFLOAT16_UNROLL - for(Index i = 0; i < num_acc; i++) - __builtin_mma_xxsetaccz(&(quad_acc[i])); - - for(; k + 2 <= depth; k += 2){ - KLoop(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols); - } - if(depth&1){ - KLoop(indexA-(offset_row&(num_packets-1)), indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols); + Index x = 0; + do{ + Packet4f result_block = ploadu(result); + result_block = pmadd(acc[x], pAlpha, result_block); + if (lhsExtraRows) { + pstoreu_partial(result, result_block, extra_rows); + } else { + pstoreu(result, result_block); } + result += rows; + } while (++x < (rhsExtraCols ? extra_cols : 4)); +} - BFLOAT16_UNROLL - for(Index i = 0; i < num_acc; i++) - __builtin_mma_disassemble_acc((void*)acc[i], &(quad_acc[i])); +#define MAX_BFLOAT16_ACC 8 - for(Index i = 0; i < num_acc; i++){ - if(lhsExtraRows){ - float *r = result + (col+i*4)*rows + row; - for(Index x = 0; x < extra_cols; x++, r += rows){ - Packet4f result_block = ploadu_partial(r, extra_rows); - result_block = pmadd(acc[i][x], pAlpha, result_block); - pstoreu_partial(r, result_block, extra_rows); - } +template +void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, Index extra_rows) +{ + const Index step = (num_acc * 4); //each accumulator has 4 elements + const Index extra_cols = (rhsExtraCols) ? 
(cols & 3) : 0; + + do{ + for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += 8, result += 4) { + Index k; + Packet4f acc[num_acc][4]; + __vector_quad quad_acc[num_acc]; + + BFLOAT16_UNROLL + for(k = 0; k < num_acc; k++) + __builtin_mma_xxsetaccz(&(quad_acc[k])); + + for(k = 0; k + 2 <= depth; k += 2){ + KLoop(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows); } - else{ - if(rhsExtraCols){ - float *r = result + (col+i*4)*rows + row + offset_row; - for(Index x = 0; x < cols-col; x++, r += rows){ - scaleAndStore(r,acc[i][x], pAlpha); - } - } - else{ - float *r = result + (col+i*4)*rows + (block_index*16) + offset_row; - for(Index x = 0; x < 4; x++, r += rows){ - scaleAndStore(r,acc[i][x], pAlpha); - } - } + if(depth&1){ + KLoop(indexA - offset_row, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows); } + + BFLOAT16_UNROLL + for(k = 0; k < num_acc; k++) + __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k])); + + for(k = 0; k < (num_acc - 1); k++){ + storeResults(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows); + } + storeResults(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows); } - if(rhsExtraCols) return; - indexB += strideB*step; - col += step; + + indexA -= num_packets*2; + indexB += strideB*num_acc; + result += (rows*step - num_packets); + } while(!rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC) && (step <= cols - (col += step))); +} + +template +EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows) +{ + if (MAX_BFLOAT16_ACC > num_acc) { + colLoopBody(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); } } +template +void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows) +{ + switch ((cols - col) >> 2) { + case 7: + colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + break; + case 6: + colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + break; + case 5: + colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + break; + case 4: + colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + break; + case 3: + colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + break; + case 2: + colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + break; + case 1: + colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + break; + default: + if (rhsExtraCols) { + colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + } + break; + } +} + +template +EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, 
const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows = 0) +{ + Index col = 0; + if (cols >= (MAX_BFLOAT16_ACC * 4)) { + colLoopBody(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows); + blockB += (strideB >> 2)*col; + result += rows*col; + } + if (cols & 3) { + colLoopBodyExtra(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + } else { + colLoopBodyExtra(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows); + } +} + +EIGEN_ALWAYS_INLINE Packet8bf convertF16toF32(const float *res) +{ + Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(res + 0))); + Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(res + 4))); + return vec_pack(reinterpret_cast(fp16_0), reinterpret_cast(fp16_1)); +} + template void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { @@ -157,17 +189,33 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0); typedef typename DataMapper::LinearMapper LinearMapper; + Packet4f z = pset1(float(0)); for(Index j = 0; j < cols; j++){ const LinearMapper res2 = res.getLinearMapper(0, j); float *result2 = result + j*rows; + Index i = 0; + for(; i + 32 <= rows; i+=32){ + Packet4f r32_0 = reinterpret_cast(res2.template loadPacket(i + 0).m_val); + Packet4f r32_1 = reinterpret_cast(res2.template loadPacket(i + 8).m_val); + Packet4f r32_2 = reinterpret_cast(res2.template loadPacket(i + 16).m_val); + Packet4f r32_3 = reinterpret_cast(res2.template loadPacket(i + 24).m_val); + pstore(result2 + i + 0, vec_mergeo(r32_0, z)); + pstore(result2 + i + 4, vec_mergee(r32_0, z)); + pstore(result2 + i + 8, vec_mergeo(r32_1, z)); + pstore(result2 + i + 12, vec_mergee(r32_1, z)); + pstore(result2 + i + 16, vec_mergeo(r32_2, z)); + pstore(result2 + i + 20, vec_mergee(r32_2, z)); + pstore(result2 + i + 24, vec_mergeo(r32_3, z)); + pstore(result2 + i + 28, vec_mergee(r32_3, z)); + } BFLOAT16_UNROLL - for(Index i = 0; i < rows; i++){ + for(; i < rows; i++){ result2[i] = res2(i); } } Index row = 0; - Index col = 0; + Index col; if( strideA == -1 ) strideA = depth; if( strideB == -1 ) strideB = depth; @@ -183,90 +231,35 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks Index bigSuffix = (2*8) * (strideA-offsetA); const bfloat16* indexA = blockA; - const Index offset_factor = 2; + const bfloat16* indexB = blockB + 4*offsetB; Index block_index; + strideB *= 4; + offsetB *= 3; for(block_index = 0; block_index < standard_blocks_quantity; block_index++){ indexA += 2*8*offsetA; - for(Index offset_row = 0; offset_row < standard_block_size; offset_row += 4){ //This block size has 16 rows maximum - col = 0; - colLoopBody<7, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<6, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<5, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, 
strideA, blockB, strideB, offsetB, result); - colLoopBody<4, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<3, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<2, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<1, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - if(cols > col){ - Index extra_cols= cols-col; - //Remember: It doesnt make sense use multiple acc to extra_cols as we are unrolling col loop - colLoopBody<1, 16, true>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result, extra_cols, 4); - } - } + colLoops<16>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row); row += 16; indexA += bigSuffix; } //LHS (8x8) block if(rows & 8){ indexA += 1*8*offsetA; - for(Index offset_row = 0; offset_row < 8; offset_row += 4){ - col = 0; - colLoopBody<7, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<6, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<5, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<4, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<3, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<2, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - colLoopBody<1, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result); - } - if(cols > col){ - Index extra_cols= cols-col; - - for(Index offset_row = 0; offset_row < 8; offset_row += 4){ - colLoopBody<1, 8, true>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result, extra_cols, 4); - } - } //end extra cols + colLoops<8>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row); row += 8; indexA += (bigSuffix >> 1); } //LHS (8x4) block if(rows & 4){ - Index offset_row = (rows & 8); indexA += 1*4*offsetA; - col = 0; - colLoopBody<7, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); - colLoopBody<6, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); - colLoopBody<5, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); - colLoopBody<4, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); - colLoopBody<3, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, 
indexA, strideA, blockB, strideB, offsetB, result); - colLoopBody<2, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); - colLoopBody<1, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result); - if(cols > col){ - Index extra_cols= cols-col; - - colLoopBody<1, 4, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result, extra_cols, 4); - } + colLoops<4>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row); row += 4; indexA += (bigSuffix >> 2); } //extra rows - if(row < rows){ - Index extra_rows_or_four = rows-row; - + Index extra_rows = rows & 3; + if(extra_rows){ //This index is the beginning of remaining block. - col = 0; - colLoopBody<7, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four); - colLoopBody<6, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four); - colLoopBody<5, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four); - colLoopBody<4, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four); - colLoopBody<3, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four); - colLoopBody<2, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four); - colLoopBody<1, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four); - if(cols > col){ - Index extra_cols= cols-col; - - colLoopBody<1, 8, true, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, extra_cols, extra_rows_or_four); - } - row += extra_rows_or_four; + colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row, extra_rows); } //Convert back to bfloat16 @@ -276,9 +269,7 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat //get and save block PacketBlock block; for(Index j = 0; j < 4; j++){ - Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(result + (col + j)*rows + row))); - Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(result + (col + j)*rows + row + 4))); - block.packet[j].m_val = vec_pack(reinterpret_cast(fp16_0), reinterpret_cast(fp16_1)); + block.packet[j].m_val = convertF16toF32(result + (col + j)*rows + row); } res2.template storePacketBlock(row, 0, block); @@ -295,8 +286,14 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat //extra cols while(col < cols){ const LinearMapper res2 = res.getLinearMapper(0, col); - for(Index r= 0; r< rows; r++){ - res2(r) = Eigen::bfloat16(result[col*rows + r]); + float *result2 = result + col*rows; + Index r = 0; + for(; r + 8 <= rows; r += 8){ + Packet8bf fp16 = convertF16toF32(result2 + r); + res2.template storePacket(r, fp16); + } + for(; r< rows; r++){ + res2(r) = Eigen::bfloat16(result2[r]); } col++; }
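
Note on the column dispatch (illustration only, not part of the patch): folding the extra columns into one more MMA accumulator means a ragged remainder of 1-3 columns no longer needs a separate path; the tail call simply requests one accumulator beyond the number of full 4-column groups and writes back only the columns that exist. Below is a minimal standalone sketch of that partitioning, assuming MAX_BFLOAT16_ACC == 8 and 4 float columns per accumulator as in the patch; kMaxAcc, kColsPerAcc and accumulatorsFor are hypothetical names used only for this illustration.

#include <cstdio>

constexpr int kMaxAcc = 8;      // MAX_BFLOAT16_ACC in the patch
constexpr int kColsPerAcc = 4;  // each MMA accumulator holds 4 float columns

// Number of accumulators one tail call needs for `remainingCols` columns:
// one per full group of 4 columns, plus one extra accumulator that absorbs
// a ragged remainder of 1-3 columns (only those columns are written back).
int accumulatorsFor(int remainingCols)
{
  int acc = remainingCols / kColsPerAcc + ((remainingCols % kColsPerAcc) ? 1 : 0);
  return (acc > kMaxAcc) ? kMaxAcc : acc;
}

int main()
{
  const int tests[] = {32, 35, 7, 3};
  for (int cols : tests) {
    int col = 0;
    // Main loop: full blocks of kMaxAcc accumulators (32 columns) at a time,
    // mirroring the colLoopBody<MAX_BFLOAT16_ACC, ...> call driven from colLoops.
    while (cols - col >= kMaxAcc * kColsPerAcc) {
      std::printf("cols=%2d col=%2d: %d accumulators (full block)\n", cols, col, kMaxAcc);
      col += kMaxAcc * kColsPerAcc;
    }
    // Tail: a single call with 1-8 accumulators; when (cols & 3) != 0 the last
    // accumulator also covers those extra columns (colLoopBodyExtra's switch).
    if (col < cols) {
      std::printf("cols=%2d col=%2d: %d accumulators (tail, %d extra cols)\n",
                  cols, col, accumulatorsFor(cols - col), cols & 3);
    }
  }
  return 0;
}

For cols = 35 this reports one full 8-accumulator block for columns 0-31 and one tail call with a single accumulator covering the remaining 3 columns, which corresponds to the default case of colLoopBodyExtra above calling colLoopBody<1, num_packets, true, lhsExtraRows>.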