diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 7c4a42bea..50270603b 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -841,7 +841,7 @@ struct dhs_pack } }; -#ifdef __MMA__ +#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__) // General template for lhs packing, bfloat16 specialization. template struct dhs_pack @@ -1162,6 +1162,7 @@ struct dhs_pack t1 = reinterpret_cast(vec_mergeh(reinterpret_cast(block.packet[2].m_val), reinterpret_cast(block.packet[3].m_val))); t2 = reinterpret_cast(vec_mergel(reinterpret_cast(block.packet[0].m_val), reinterpret_cast(block.packet[1].m_val))); t3 = reinterpret_cast(vec_mergel(reinterpret_cast(block.packet[2].m_val), reinterpret_cast(block.packet[3].m_val))); + block.packet[0] = reinterpret_cast(vec_mergeh(t0, t1)); block.packet[1] = reinterpret_cast(vec_mergel(t0, t1)); block.packet[2] = reinterpret_cast(vec_mergeh(t2, t3)); @@ -2736,7 +2737,7 @@ void gemm_pack_rhs struct gemm_pack_rhs @@ -3270,7 +3271,7 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjug gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } -#if defined(__MMA__) +#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__) template struct gebp_kernel { @@ -3288,10 +3289,7 @@ void gebp_kernel::rows; - const Index accCols = quad_traits::size; - - Eigen::internal::gemmMMAbfloat16(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + Eigen::internal::gemmMMAbfloat16(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } #endif } // end namespace internal diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h index e4013a747..05d180c0c 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -28,7 +28,9 @@ #include "../../InternalHeaderCheck.h" +#if !EIGEN_ALTIVEC_DISABLE_MMA #include "MatrixProductMMAbfloat16.h" +#endif namespace Eigen { diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h index bdf434722..976a73ff5 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h @@ -29,7 +29,7 @@ EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index stri return loadBfloat16(blockB + strideB*i); } -template +template EIGEN_ALWAYS_INLINE void KLoop ( const bfloat16* indexA, @@ -42,250 +42,277 @@ EIGEN_ALWAYS_INLINE void KLoop Index extra_rows ) { - Packet8bf lhs = loadBfloat16(indexA + k*(lhsExtraRows ? extra_rows : num_packets)); //a packet of bfloat16 has 8 elements - Packet8bf rhs[num_acc]; + Packet8bf lhs[num_lhs], rhs[num_rhs]; - for(Index i = 0; i < (num_acc - (rhsExtraCols ? 1 : 0)); i++){ + for(Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++){ rhs[i] = loadRhsBfloat16(indexB + k*4, strideB, i); } if(rhsExtraCols) { - rhs[num_acc - 1] = loadRhsBfloat16(indexB + k*extra_cols - offsetB, strideB, num_acc - 1); + rhs[num_rhs - 1] = loadRhsBfloat16(indexB + k*extra_cols - offsetB, strideB, num_rhs - 1); + } + + indexA += k*(lhsExtraRows ? extra_rows : num_packets); + for(Index j = 0; j < num_lhs; j++) { + lhs[j] = loadBfloat16(indexA + j*(zero ? 
4 : 8)); //a packet of bfloat16 has 8 elements } BFLOAT16_UNROLL - for (Index i = 0; i < num_acc; i++) { - __builtin_mma_xvbf16ger2pp(&(quad_acc[i]), reinterpret_cast(rhs[i].m_val), reinterpret_cast(lhs.m_val)); + for(Index i = 0, k = 0; i < num_rhs; i++) { + BFLOAT16_UNROLL + for(Index j = 0; j < num_lhs; j++, k++) { + __builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast(rhs[i].m_val), reinterpret_cast(lhs[j].m_val)); + } } } -template +EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result) +{ + Packet4f result_block = ploadu(result); + return pmadd(acc, pAlpha, result_block); +} + +template +EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows) +{ + if (lhsExtraRows) { + pstoreu_partial(result, result_block, extra_rows); + } else { + pstoreu(result, result_block); + } + result += rows; +} + +template EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows) { Index x = 0; - do{ - Packet4f result_block = ploadu(result); - result_block = pmadd(acc[x], pAlpha, result_block); - if (lhsExtraRows) { - pstoreu_partial(result, result_block, extra_rows); - } else { - pstoreu(result, result_block); + if (rhsExtraCols) { + do{ + Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result); + storeF32(result, result_block, rows, extra_rows); + } while (++x < extra_cols); + } else { + Packet4f result_block[4]; + float *result2 = result; + do{ + result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result); + result += rows; + } while (++x < 4); + x = 0; + do{ + storeF32(result2, result_block[x], rows, extra_rows); + } while (++x < 4); + } +} + +template +EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc]) +{ + BFLOAT16_UNROLL + for(Index k = 0; k < num_acc; k++) + __builtin_mma_xxsetaccz(&(quad_acc[k])); +} + +template +EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_acc], Packet4f (&acc)[num_acc][4]) +{ + BFLOAT16_UNROLL + for(Index k = 0; k < num_acc; k++) + __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k])); +} + +template +EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows) +{ + for(Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4*rows){ + for(Index j = 0; j < num_lhs; j++, k++) { + storeResults(acc[k], rows, pAlpha, result + j*4, extra_cols, extra_rows); } - result += rows; - } while (++x < (rhsExtraCols ? extra_cols : 4)); + } + if(rhsExtraCols) { + storeResults(acc[num_acc - 1], rows, pAlpha, result, extra_cols, extra_rows); + } +} + +template +EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, const Index extra_cols, const Index extra_rows) +{ + constexpr Index num_lhs = multiIter ? (num_packets / 4) : 1; + constexpr Index num_rhs = (num_acc + num_lhs - 1) / num_lhs; + + for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += (multiIter ? 0 : 8), indexB += (multiIter ? (num_rhs*strideB) : 0), result += (multiIter ? 
(4*rows*num_rhs) : 4)) { + Packet4f acc[num_acc][4]; + __vector_quad quad_acc[num_acc]; + + zeroAccumulators(quad_acc); + + Index k; + for(k = 0; k + 2 <= depth; k += 2){ + KLoop(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows); + } + if(depth&1){ + KLoop(indexA - (multiIter ? 0 : offset_row), indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows); + } + + disassembleAccumulators(quad_acc, acc); + + outputResults(acc, rows, pAlpha, result, extra_cols, extra_rows); + } } #define MAX_BFLOAT16_ACC 8 template -void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, Index extra_rows) +void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result) { - const Index step = (num_acc * 4); //each accumulator has 4 elements + constexpr Index step = (num_acc * 4); //each accumulator has 4 elements const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0; + const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0; + constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC); do{ - for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += 8, result += 4) { - Index k; - Packet4f acc[num_acc][4]; - __vector_quad quad_acc[num_acc]; - - BFLOAT16_UNROLL - for(k = 0; k < num_acc; k++) - __builtin_mma_xxsetaccz(&(quad_acc[k])); - - for(k = 0; k + 2 <= depth; k += 2){ - KLoop(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows); - } - if(depth&1){ - KLoop(indexA - offset_row, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows); - } - - BFLOAT16_UNROLL - for(k = 0; k < num_acc; k++) - __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k])); - - for(k = 0; k < (num_acc - 1); k++){ - storeResults(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows); - } - storeResults(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows); + if (multiIters && ((num_acc % (num_packets / 4)) == 0)) { + colLoopBodyIter(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows); + } else { + colLoopBodyIter(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows); } - indexA -= num_packets*2; indexB += strideB*num_acc; - result += (rows*step - num_packets); - } while(!rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC) && (step <= cols - (col += step))); + result += rows*step; + } while(multiIters && (step <= cols - (col += step))); } template -EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows) +EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result) { if (MAX_BFLOAT16_ACC > num_acc) { - colLoopBody(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBody(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); } } template -void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows) +void 
colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result) { switch ((cols - col) >> 2) { case 7: - colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); break; case 6: - colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); break; case 5: - colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); break; case 4: - colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); break; case 3: - colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); break; case 2: - colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); break; case 1: - colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); break; default: if (rhsExtraCols) { - colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); } break; } } template -EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows = 0) +EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result) { Index col = 0; if (cols >= (MAX_BFLOAT16_ACC * 4)) { - colLoopBody(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows); + colLoopBody(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result); blockB += (strideB >> 2)*col; result += rows*col; } if (cols & 3) { - colLoopBodyExtra(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows); + colLoopBodyExtra(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, 
result); } else { - colLoopBodyExtra(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows); + colLoopBodyExtra(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result); } } -EIGEN_ALWAYS_INLINE Packet8bf convertF16toF32(const float *res) +template +EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res) { Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(res + 0))); - Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(res + 4))); + Packet16uc fp16_1 = (full) ? __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(res + 4))) : fp16_0; return vec_pack(reinterpret_cast(fp16_0), reinterpret_cast(fp16_1)); } -template -void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +template +EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock& block) { - if(rows == 0 || cols == 0 || depth == 0) return; - float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha); - if (falpha == float(0)) return; - const Packet4f pAlpha = pset1(falpha); - ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0); - - typedef typename DataMapper::LinearMapper LinearMapper; Packet8us z = pset1(0); - for(Index j = 0; j < cols; j++){ - const LinearMapper res2 = res.getLinearMapper(0, j); - float *result2 = result + j*rows; + pstore(to + 0, reinterpret_cast(vec_mergeh(z, block.packet[0].m_val))); + if (N >= 8) { + pstore(to + 4, reinterpret_cast(vec_mergel(z, block.packet[0].m_val))); + } + if (N >= 16) { + pstore(to + 8, reinterpret_cast(vec_mergeh(z, block.packet[1].m_val))); + pstore(to + 12, reinterpret_cast(vec_mergel(z, block.packet[1].m_val))); + } + if (N >= 32) { + pstore(to + 16, reinterpret_cast(vec_mergeh(z, block.packet[2].m_val))); + pstore(to + 20, reinterpret_cast(vec_mergel(z, block.packet[2].m_val))); + pstore(to + 24, reinterpret_cast(vec_mergeh(z, block.packet[3].m_val))); + pstore(to + 28, reinterpret_cast(vec_mergel(z, block.packet[3].m_val))); + } +} + +template +EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src) +{ + for(; i + size <= rows; i += size){ + PacketBlock r32; + r32.packet[0] = src.template loadPacket(i + 0); + if (size >= 16) { + r32.packet[1] = src.template loadPacket(i + 8); + } + if (size >= 32) { + r32.packet[2] = src.template loadPacket(i + 16); + r32.packet[3] = src.template loadPacket(i + 24); + } + storeConvertBlockBF16(result + i, r32); + } +} + +template +EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src) +{ + typedef typename DataMapper::LinearMapper LinearMapper; + for(Index j = 0; j < cols; j++, result += rows){ + const LinearMapper src2 = src.getLinearMapper(0, j); Index i = 0; - for(; i + 32 <= rows; i+=32){ - Packet8us r32_0 = res2.template loadPacket(i + 0).m_val; - Packet8us r32_1 = res2.template loadPacket(i + 8).m_val; - Packet8us r32_2 = res2.template loadPacket(i + 16).m_val; - Packet8us r32_3 = res2.template loadPacket(i + 24).m_val; - pstore(result2 + i + 0, reinterpret_cast(vec_mergeh(z, r32_0))); - pstore(result2 + i + 4, reinterpret_cast(vec_mergel(z, r32_0))); - pstore(result2 + i + 8, reinterpret_cast(vec_mergeh(z, r32_1))); - pstore(result2 + i + 12, reinterpret_cast(vec_mergel(z, r32_1))); - pstore(result2 + i + 16, reinterpret_cast(vec_mergeh(z, r32_2))); - pstore(result2 + i + 20, 
reinterpret_cast(vec_mergel(z, r32_2))); - pstore(result2 + i + 24, reinterpret_cast(vec_mergeh(z, r32_3))); - pstore(result2 + i + 28, reinterpret_cast(vec_mergel(z, r32_3))); - } - for(; i + 16 <= rows; i+=16){ - Packet8us r32_0 = res2.template loadPacket(i + 0).m_val; - Packet8us r32_1 = res2.template loadPacket(i + 8).m_val; - pstore(result2 + i + 0, reinterpret_cast(vec_mergeh(z, r32_0))); - pstore(result2 + i + 4, reinterpret_cast(vec_mergel(z, r32_0))); - pstore(result2 + i + 8, reinterpret_cast(vec_mergeh(z, r32_1))); - pstore(result2 + i + 12, reinterpret_cast(vec_mergel(z, r32_1))); - } - for(; i + 8 <= rows; i+=8){ - Packet8us r32_0 = res2.template loadPacket(i + 0).m_val; - pstore(result2 + i + 0, reinterpret_cast(vec_mergeh(z, r32_0))); - pstore(result2 + i + 4, reinterpret_cast(vec_mergel(z, r32_0))); - } - for(; i + 4 <= rows; i+=4){ - Packet8us r32_0 = res2.template loadPacketPartial(i + 0, 4).m_val; - pstore(result2 + i + 0, reinterpret_cast(vec_mergeh(z, r32_0))); - } + convertBF16toF32<32, LinearMapper>(i, result, rows, src2); + convertBF16toF32<16, LinearMapper>(i, result, rows, src2); + convertBF16toF32<8, LinearMapper>(i, result, rows, src2); + convertBF16toF32<4, LinearMapper>(i, result, rows, src2); for(; i < rows; i++){ - result2[i] = res2(i); + result[i] = Eigen::bfloat16_impl::bfloat16_to_float(src2(i)); } } +} - Index row = 0; - Index col; - - if( strideA == -1 ) strideA = depth; - if( strideB == -1 ) strideB = depth; - //Packing is done in blocks. - //There's 4 possible sizes of blocks - //Blocks of 8 columns with 16 elements (8x16) - //Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8 - //Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4 - //Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows - - //Loop for LHS standard block (8x16) - const Index standard_block_size = 16; - const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks - Index bigSuffix = (2*8) * (strideA-offsetA); - const bfloat16* indexA = blockA; - const bfloat16* indexB = blockB + 4*offsetB; - Index block_index; - strideB *= 4; - offsetB *= 3; - for(block_index = 0; block_index < standard_blocks_quantity; block_index++){ - indexA += 2*8*offsetA; - colLoops<16>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row); - row += 16; - indexA += bigSuffix; - } - //LHS (8x8) block - if(rows & 8){ - indexA += 1*8*offsetA; - colLoops<8>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row); - row += 8; - indexA += (bigSuffix >> 1); - } - //LHS (8x4) block - if(rows & 4){ - indexA += 1*4*offsetA; - colLoops<4>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row); - row += 4; - indexA += (bigSuffix >> 2); - } - //extra rows - Index extra_rows = rows & 3; - if(extra_rows){ - //This index is the beginning of remaining block. 
- colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row, extra_rows); - } - - //Convert back to bfloat16 +template +EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper& res) +{ + typedef typename DataMapper::LinearMapper LinearMapper; + Index col, row; for(col = 0; col + 4 <= cols; col += 4){ const DataMapper res2 = res.getSubMapper(0, col); for(row = 0; row + 8 <= rows; row += 8){ //get and save block PacketBlock block; for(Index j = 0; j < 4; j++){ - block.packet[j].m_val = convertF16toF32(result + (col + j)*rows + row); + block.packet[j].m_val = convertF32toBF16(result + (col + j)*rows + row); } res2.template storePacketBlock(row, 0, block); @@ -303,18 +330,70 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat while(col < cols){ const LinearMapper res2 = res.getLinearMapper(0, col); float *result2 = result + col*rows; - Index r = 0; - for(; r + 8 <= rows; r += 8){ - Packet8bf fp16 = convertF16toF32(result2 + r); - res2.template storePacket(r, fp16); + for(row = 0; row + 8 <= rows; row += 8){ + Packet8bf fp16 = convertF32toBF16(result2 + row); + res2.template storePacket(row, fp16); } - for(; r< rows; r++){ - res2(r) = Eigen::bfloat16(result2[r]); + for(; row < rows; row++){ + res2(row) = Eigen::bfloat16(result2[row]); } col++; } } +template +EIGEN_ALWAYS_INLINE void calcColLoops(const bfloat16*& indexA, Index& row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexB, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result) +{ + if ((size == 16) || (rows & size)) { + indexA += size*offsetA; + colLoops(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row); + row += size; + indexA += bigSuffix*size/16; + } +} + +template +void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha); + if (falpha == float(0)) return; + const Packet4f pAlpha = pset1(falpha); + ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0); + + convertArrayBF16toF32(result, cols, rows, res); + + Index row = 0; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + //Packing is done in blocks. + //There's 4 possible sizes of blocks + //Blocks of 8 columns with 16 elements (8x16) + //Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8 + //Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4 + //Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows + + //Loop for LHS standard block (8x16) + Index bigSuffix = (2*8) * (strideA-offsetA); + indexB += 4*offsetB; + strideB *= 4; + offsetB *= 3; + while(row + 16 <= rows){ + calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result); + } + //LHS (8x8) block + calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result); + //LHS (8x4) block + calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result); + //extra rows + if(rows & 3){ + //This index is the beginning of remaining block. 
+    colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
+  }
+
+  //Convert back to bfloat16
+  convertArrayF32toBF16(result, cols, rows, res);
+}
 }
 }
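
For reference, here is a stand-alone sketch of the Power10 MMA pattern that the refactored bfloat16 path above relies on: zero a __vector_quad accumulator, feed it with __builtin_mma_xvbf16ger2pp inside the K loop (as KLoop/colLoopBodyIter do), disassemble it, scale by alpha and add into the float result (as storeResults does), and convert float back to bfloat16 with __builtin_vsx_xvcvspbf16 (as convertF32toBF16 does). It is an illustrative approximation only, not Eigen's kernel: the helper names (bf16_mma_tile, f32_to_bf16x8) are hypothetical, the operands are assumed to be pre-packed into the element order xvbf16ger2pp expects (the real code does this in dhs_pack), odd-depth and partial-tile handling are omitted, and it assumes a GCC/Clang toolchain targeting POWER10 with MMA enabled (e.g. -mcpu=power10 -mmma).

// Illustrative sketch (not Eigen's kernel). Assumes POWER10 MMA is available
// (__MMA__ defined) and that bfloat16 values are raw uint16_t bit patterns.
#include <altivec.h>
#include <cstdint>

typedef __vector unsigned char vec_uc;

// C(4x4) += alpha * A(4 x depth) * B(depth x 4). A and B are assumed to be
// pre-packed so that each 16-byte load supplies the 4x2 / 2x4 bfloat16
// sub-matrices consumed by one xvbf16ger2pp; odd-depth tails are omitted
// (the patch handles them via the `zero` template flag in KLoop).
static void bf16_mma_tile(float* C, long ldc, const uint16_t* A, const uint16_t* B,
                          long depth, float alpha)
{
  __vector_quad acc;
  __builtin_mma_xxsetaccz(&acc);                  // zero the 4x4 f32 accumulator

  for (long k = 0; k + 2 <= depth; k += 2) {
    vec_uc a = vec_xl(0, reinterpret_cast<const unsigned char*>(A + 4 * k));
    vec_uc b = vec_xl(0, reinterpret_cast<const unsigned char*>(B + 4 * k));
    // Rank-2 (two k values) bfloat16 update of the whole 4x4 accumulator.
    __builtin_mma_xvbf16ger2pp(&acc, a, b);
  }

  // Spill the accumulator into 4 float vectors (one per accumulator row),
  // then compute result = acc * alpha + result, as storeResults does above.
  __vector float rows[4];
  __builtin_mma_disassemble_acc(rows, &acc);
  const __vector float valpha = vec_splats(alpha);
  for (int i = 0; i < 4; i++) {
    __vector float c = vec_xl(0, C + i * ldc);
    vec_xst(vec_madd(rows[i], valpha, c), 0, C + i * ldc);
  }
}

// Mirrors the patch's convertF32toBF16: convert 8 floats to 8 bfloat16 values
// with the VSX convert instruction, then narrow the two halves with vec_pack.
// Exact lane placement is endian-dependent; Eigen's pack/merge code handles it.
static __vector unsigned short f32_to_bf16x8(const float* p)
{
  vec_uc lo = __builtin_vsx_xvcvspbf16(reinterpret_cast<vec_uc>(vec_xl(0, p)));
  vec_uc hi = __builtin_vsx_xvcvspbf16(reinterpret_cast<vec_uc>(vec_xl(0, p + 4)));
  return vec_pack(reinterpret_cast<__vector unsigned int>(lo),
                  reinterpret_cast<__vector unsigned int>(hi));
}

The patch's colLoopBodyIter generalizes this single-accumulator loop to up to MAX_BFLOAT16_ACC (8) accumulators fed from multiple LHS and RHS packets per K iteration, which is where the num_lhs/num_rhs split of the former num_acc parameter comes from.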