From 9d724123852b399979e72896987d81a9826386d0 Mon Sep 17 00:00:00 2001
From: Chip Kerchner
Date: Mon, 13 Mar 2023 19:37:13 +0000
Subject: [PATCH] Add MMA to BF16 GEMV - 5.0-6.3X faster (for Power)

---
 .../arch/AltiVec/MatrixProductMMAbfloat16.h   | 487 +++++++++++++++++-
 .../Core/arch/AltiVec/MatrixVectorProduct.h   |  33 ++
 2 files changed, 517 insertions(+), 3 deletions(-)

diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
index 976a73ff5..4774587f5 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
@@ -146,8 +146,8 @@ EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f
 
   zeroAccumulators(quad_acc);
 
-  Index k;
-  for(k = 0; k + 2 <= depth; k += 2){
+  Index k = 0;
+  for(Index j = depth >> 1; j--; k += 2){
     KLoop(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
   }
   if(depth&1){
@@ -356,7 +356,6 @@ template<typename DataMapper>
 void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
 {
   float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
-  if (falpha == float(0)) return;
   const Packet4f pAlpha = pset1<Packet4f>(falpha);
 
   ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
@@ -395,6 +394,488 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat
   convertArrayF32toBF16(result, cols, rows, res);
 }
 
+template<const Index size, bool inc, Index delta>
+EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc)
+{
+  if (inc) {
+    if (size == 4) {
+      pscatter_partial(dst + delta*resInc, data, resInc, 4);
+    } else {
+      pscatter(dst + delta*resInc, data, resInc);
+    }
+  } else {
+    if (size == 4) {
+      pstoreu_partial(dst + delta, data, 4);
+    } else {
+      pstoreu(dst + delta, data);
+    }
+  }
+}
+
+template<const Index size, bool inc>
+EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc)
+{
+  for(Index j = (rows - i) / size; j--; i += size, dst += size*resInc){
+    PacketBlock<Packet8bf,(size+7)/8> r32;
+    r32.packet[0] = convertF32toBF16(result + i + 0);
+    if (size >= 16) {
+      r32.packet[1] = convertF32toBF16(result + i + 8);
+    }
+    if (size >= 32) {
+      r32.packet[2] = convertF32toBF16(result + i + 16);
+      r32.packet[3] = convertF32toBF16(result + i + 24);
+    }
+    storeBF16fromResult<size, inc, 0>(dst, r32.packet[0], resInc);
+    if (size >= 16) {
+      storeBF16fromResult<size, inc, 8>(dst, r32.packet[1], resInc);
+    }
+    if (size >= 32) {
+      storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
+      storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
+    }
+  }
+}
+
+template<bool inc, Index delta>
+EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc)
+{
+  if (inc) {
+    return pgather<bfloat16, Packet8bf>(src + delta*resInc, resInc);
+  } else {
+    return ploadu<Packet8bf>(src + delta);
+  }
+}
+
+template<const Index size, bool inc>
+EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
+{
+  for(; i + size <= rows; i += size, src += size*resInc){
+    PacketBlock<Packet8bf,(size+7)/8> r32;
+    r32.packet[0] = loadBF16fromResult<inc, 0>(src, resInc);
+    if (size >= 16) {
+      r32.packet[1] = loadBF16fromResult<inc, 8>(src, resInc);
+    }
+    if (size >= 32) {
+      r32.packet[2] = loadBF16fromResult<inc, 16>(src, resInc);
+      r32.packet[3] = loadBF16fromResult<inc, 24>(src, resInc);
+    }
+    storeConvertBlockBF16(result + i, r32);
+  }
+}
+
+template<bool inc = false>
+EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index rows, bfloat16* src, Index resInc = 1)
+{
+  Index i = 0;
+  convertPointerBF16toF32<32, inc>(i, result, rows, src, resInc);
+  convertPointerBF16toF32<16, inc>(i, result, rows, src, resInc);
+  convertPointerBF16toF32<8, inc>(i, result, rows, src, resInc);
+  convertPointerBF16toF32<4, inc>(i, result, rows, src, resInc);
+  for(; i < rows; i++, src += resInc){
+    result[i] = Eigen::bfloat16_impl::bfloat16_to_float(*src);
+  }
+}
+
+template<bool inc = false>
+EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1)
+{
+  Index i = 0;
+  convertPointerF32toBF16<32,inc>(i, result, rows, dst, resInc);
+  convertPointerF32toBF16<16,inc>(i, result, rows, dst, resInc);
+  convertPointerF32toBF16<8,inc>(i, result, rows, dst, resInc);
+  convertPointerF32toBF16<4,inc>(i, result, rows, dst, resInc);
+  for(; i < rows; i++, dst += resInc){
+    *dst = Eigen::bfloat16(result[i]);
+  }
+}
+
+template<bool extraRows>
+EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows)
+{
+  Packet4f d0 = ploadu<Packet4f>(result);
+  d0 = pmadd(acc, pAlpha, d0);
+  if (extraRows) {
+    pstoreu_partial(result, d0, extra_rows);
+  } else {
+    pstoreu(result, d0);
+  }
+}
+
+template<Index num_acc, bool extraRows>
+EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha, Index extra_rows)
+{
+  for(Index k = 0; k < num_acc - (extraRows ? 1 : 0); k++) {
+    outputVecCol<false>(acc[k][0], result + k*4, pAlpha, extra_rows);
+  }
+  if (extraRows) {
+    outputVecCol<true>(acc[num_acc - 1][0], result + (num_acc - 1)*4, pAlpha, extra_rows);
+  }
+}
+
+template<Index num_acc, typename LhsMapper, bool zero>
+EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper& lhs, Packet8bf (&a0)[num_acc], Packet8bf b1)
+{
+  a0[k + 0] = lhs.template loadPacket<Packet8bf>(k*4, 0);
+  if (!zero) {
+    b1 = lhs.template loadPacket<Packet8bf>(k*4, 1);
+  }
+  if (num_acc > (k + 1)) {
+    a0[k + 1] = vec_mergel(a0[k + 0].m_val, b1.m_val);
+  }
+  a0[k + 0] = vec_mergeh(a0[k + 0].m_val, b1.m_val);
+}
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (&a0)[num_acc], Packet8bf b0)
+{
+  BFLOAT16_UNROLL
+  for(Index k = 0; k < num_acc; k++) {
+    __builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(b0.m_val), reinterpret_cast<Packet16uc>(a0[k].m_val));
+  }
+}
+
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
+EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc])
+{
+  Packet8bf a0[num_acc];
+  Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
+  Packet8bf b0 = rhs.template loadPacket<Packet8bf>(j + 0);
+
+  if (zero) {
+    b0 = vec_mergeh(b0.m_val, b1.m_val);
+  }
+
+  LhsMapper lhs2 = lhs.getSubMapper(0, j);
+  for(Index k = 0; k < num_acc; k += 2) {
+    loadVecLoop<num_acc, LhsMapper, zero>(k, lhs2, a0, b1);
+  }
+
+  multVec(quad_acc, a0, b0);
+}
+
+#define MAX_BFLOAT16_VEC_ACC 8
+
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  constexpr Index step = (num_acc * 4);
+  const Index extra_rows = (extraRows) ? (rows & 3) : 0;
+  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC);
+
+  do{
+    Packet4f acc[num_acc][4];
+    __vector_quad quad_acc[num_acc];
+
+    zeroAccumulators(quad_acc);
+
+    LhsMapper lhs2 = lhs.getSubMapper(row, 0);
+    Index j = 0;
+    for(Index k = cend >> 1; k--; j += 2) {
+      vecColLoop<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, quad_acc);
+    }
+    if (cend & 1) {
+      vecColLoop<num_acc, LhsMapper, RhsMapper, true>(j, lhs2, rhs, quad_acc);
+    }
+
+    disassembleAccumulators(quad_acc, acc);
+
+    outputVecColResults<num_acc, extraRows>(acc, result, pAlpha, extra_rows);
+
+    result += step;
+  } while(multiIters && (step <= rows - (row += step)));
+}
+
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
+    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+  }
+}
+
+template<typename LhsMapper, typename RhsMapper, bool extraRows>
+EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  switch ((rows - row) >> 2) {
+  case 7:
+    colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 6:
+    colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 5:
+    colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 4:
+    colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 3:
+    colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 2:
+    colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 1:
+    colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  default:
+    if (extraRows) {
+      colVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+    }
+    break;
+  }
+}
+
+template<typename LhsMapper, typename RhsMapper>
+EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  Index row = 0;
+  if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
+    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    result += row;
+  }
+  if (rows & 3) {
+    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+  } else {
+    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+  }
+}
+
+template<typename LhsMapper, typename RhsMapper>
+void gemvMMA_bfloat16_col(
+    Index rows, Index cols,
+    const LhsMapper& alhs,
+    const RhsMapper& rhs,
+    bfloat16* res, Index resIncr,
+    bfloat16 alpha)
+{
+  typedef typename RhsMapper::LinearMapper LinearMapper;
+
+  EIGEN_UNUSED_VARIABLE(resIncr);
+  eigen_internal_assert(resIncr == 1);
+
+  // The following copy tells the compiler that lhs's attributes are not modified outside this function
+  // This helps GCC to generate proper code.
+  LhsMapper lhs(alhs);
+  RhsMapper rhs2(rhs);
+
+  const Index lhsStride = lhs.stride();
+
+  // TODO: improve the following heuristic:
+  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
+  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
+  Packet4f pAlpha = pset1<Packet4f>(falpha);
+
+  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
+
+  convertArrayPointerBF16toF32(result, rows, res);
+
+  for (Index j2 = 0; j2 < cols; j2 += block_cols)
+  {
+    Index jend = numext::mini(j2 + block_cols, cols);
+
+    LhsMapper lhs2 = lhs.getSubMapper(0, j2);
+    LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
+    calcVecColLoops(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+  }
+
+  convertArrayPointerF32toBF16(result, rows, res);
+}
+
+static Packet16uc p16uc_ELEMENT_VEC3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f, 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f };
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f (&acc)[num_acc][4], Index k)
+{
+  if (num_acc > (k + 1)) {
+    acc[k][0] = vec_mergeh(acc[k][0], acc[k + 1][0]);
+    acc[k][1] = vec_mergeo(acc[k][1], acc[k + 1][1]);
+    acc[k][2] = vec_mergel(acc[k][2], acc[k + 1][2]);
+    acc[k][3] = vec_perm(acc[k][3], acc[k + 1][3], p16uc_ELEMENT_VEC3);
+
+    acc[k][0] = (acc[k][0] + acc[k][2]) + (acc[k][1] + acc[k][3]);
+  } else {
+    acc[k][0] = vec_mergeh(acc[k][0], acc[k][1]);
+    acc[k][0] += vec_mergel(acc[k][2], acc[k][3]);
+#ifdef _BIG_ENDIAN
+    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
+#else
+    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
+#endif
+  }
+}
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4])
+{
+  for(Index k = 0; k < num_acc; k += 4) {
+    preduxVecResults2(acc, k + 0);
+    if (num_acc > (k + 2)) {
+      preduxVecResults2(acc, k + 2);
+      acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
+    }
+  }
+}
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha)
+{
+  constexpr Index extra = num_acc & 3;
+
+  for(Index k = 0; k < num_acc; k += 4) {
+    Packet4f d0 = ploadu<Packet4f>(result + k);
+    d0 = pmadd(acc[k + 0][0], pAlpha, d0);
+
+    if (num_acc > (k + 3)) {
+      pstoreu(result + k, d0);
+    } else {
+      if (extra == 3) {
+        pstoreu_partial(result + k, d0, extra);
+      } else if (extra == 2) {
+        Packet2ul d1 = reinterpret_cast<Packet2ul>(d0);
+        *(unsigned long long *)(result + k) = d1[0];
+      } else {
+        Packet4i d1 = reinterpret_cast<Packet4i>(d0);
+        *(unsigned int *)(result + k) = d1[0];
+      }
+    }
+  }
+}
+
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
+EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const LhsMapper& lhs, RhsMapper& rhs, Index j, Index extra_cols)
+{
+  Packet8bf a0[num_acc], b0;
+
+  if (extra) {
+    b0 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
+  } else {
+    b0 = rhs.template loadPacket<Packet8bf>(j);
+  }
+
+  const LhsMapper lhs2 = lhs.getSubMapper(0, j);
+  for(Index k = 0; k < num_acc; k++) {
+    if (extra) {
+      a0[k] = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
+    } else {
+      a0[k] = lhs2.template loadPacket<Packet8bf>(k, 0);
+    }
+  }
+
+  multVec(quad_acc, a0, b0);
+}
+
+template<Index num_acc, typename LhsMapper, typename RhsMapper>
+EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc], Index extra_cols)
+{
+  Index j = 0;
+  for(Index k = cols >> 3; k--; j += 8) {
+    multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs, j, extra_cols);
+  }
+
+  if (extra_cols) {
+    multVecLoop<num_acc, LhsMapper, RhsMapper, true>(quad_acc, lhs, rhs, j, extra_cols);
+  }
+}
+
+template<const Index num_acc, typename LhsMapper, typename RhsMapper>
+void colVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC);
+  const Index extra_cols = (cols & 7);
+
+  do{
+    Packet4f acc[num_acc][4];
+    __vector_quad quad_acc[num_acc];
+
+    zeroAccumulators(quad_acc);
+
+    const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
+    vecLoop(cols, lhs2, rhs, quad_acc, extra_cols);
+
+    disassembleAccumulators(quad_acc, acc);
+
+    preduxVecResults(acc);
+
+    outputVecResults(acc, result, pAlpha);
+
+    result += num_acc;
+  } while(multiIters && (num_acc <= rows - (row += num_acc)));
+}
+
+template<typename LhsMapper, typename RhsMapper>
+EIGEN_ALWAYS_INLINE void colVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  switch (rows - row) {
+  case 7:
+    colVecLoopBody<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 6:
+    colVecLoopBody<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 5:
+    colVecLoopBody<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 4:
+    colVecLoopBody<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 3:
+    colVecLoopBody<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 2:
+    colVecLoopBody<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 1:
+    colVecLoopBody<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  }
+}
+
+template<typename LhsMapper, typename RhsMapper>
+EIGEN_ALWAYS_INLINE void calcVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  Index row = 0;
+  if (rows >= MAX_BFLOAT16_VEC_ACC) {
+    colVecLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    result += row;
+  }
+  colVecLoopBodyExtra(row, cols, rows, lhs, rhs, pAlpha, result);
+}
+
+template<typename LhsMapper, typename RhsMapper>
+EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row(
+    Index rows, Index cols,
+    const LhsMapper& alhs,
+    const RhsMapper& rhs,
+    bfloat16* res, Index resIncr,
+    bfloat16 alpha)
+{
+  typedef typename RhsMapper::LinearMapper LinearMapper;
+
+  // The following copy tells the compiler that lhs's attributes are not modified outside this function
+  // This helps GCC to generate proper code.
+  LhsMapper lhs(alhs);
+  LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
+
+  eigen_internal_assert(rhs.stride() == 1);
+
+  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
+  const Packet4f pAlpha = pset1<Packet4f>(falpha);
+
+  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
+  if (resIncr == 1) {
+    convertArrayPointerBF16toF32(result, rows, res);
+  } else {
+    convertArrayPointerBF16toF32<true>(result, rows, res, resIncr);
+  }
+  calcVecLoops(cols, rows, lhs, rhs2, pAlpha, result);
+  if (resIncr == 1) {
+    convertArrayPointerF32toBF16(result, rows, res);
+  } else {
+    convertArrayPointerF32toBF16<true>(result, rows, res, resIncr);
+  }
+}
+
 }
 }
 #endif //EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
index bb84ac977..e107335b7 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
@@ -2061,6 +2061,39 @@ EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
 EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
 EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
 
+#ifdef USE_GEMV_MMA
+#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16() \
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs> \
+struct general_matrix_vector_product<Index, bfloat16, LhsMapper, ColMajor, ConjugateLhs, bfloat16, RhsMapper, ConjugateRhs> \
+{ \
+  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
+    Index rows, Index cols, \
+    const LhsMapper& lhs, \
+    const RhsMapper& rhs, \
+    bfloat16* res, Index resIncr, \
+    bfloat16 alpha) { \
+    gemvMMA_bfloat16_col(rows, cols, lhs, rhs, res, resIncr, alpha); \
+  } \
+};
+
+#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16() \
+template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs> \
+struct general_matrix_vector_product<Index, bfloat16, LhsMapper, RowMajor, ConjugateLhs, bfloat16, RhsMapper, ConjugateRhs> \
+{ \
+  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
+    Index rows, Index cols, \
+    const LhsMapper& lhs, \
+    const RhsMapper& rhs, \
+    bfloat16* res, Index resIncr, \
+    bfloat16 alpha) { \
+    gemvMMA_bfloat16_row(rows, cols, lhs, rhs, res, resIncr, alpha); \
+  } \
+};
+
+EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16()
+EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16()
+#endif
+
 template<typename ResScalar, typename PResPacket, typename ResPacket>
 EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1, ResPacket& b1)
 {
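Note (not part of the patch): a minimal usage sketch of the bfloat16 matrix-vector products this change accelerates, assuming an Eigen build targeting Power10 with MMA enabled so that USE_GEMV_MMA is defined. The matrix/vector sizes and the alpha value below are arbitrary illustrations; a column-major A * x dispatches to the ColMajor specialization (gemvMMA_bfloat16_col), while A.transpose() * x exercises the RowMajor path (gemvMMA_bfloat16_row).

// Hedged sketch only: dimensions, alpha, and the Power10/MMA build flags are assumptions.
#include <Eigen/Dense>
#include <iostream>

int main() {
  using BF16Matrix = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, Eigen::Dynamic>;
  using BF16Vector = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, 1>;

  // Generate bf16 operands by casting from float data.
  BF16Matrix A = Eigen::MatrixXf::Random(256, 256).cast<Eigen::bfloat16>();
  BF16Vector x = Eigen::VectorXf::Random(256).cast<Eigen::bfloat16>();
  Eigen::bfloat16 alpha(2.0f);

  // y uses the column-major GEMV kernel; z uses the row-major (transposed) kernel.
  BF16Vector y = alpha * (A * x);
  BF16Vector z = alpha * (A.transpose() * x);

  std::cout << static_cast<float>(y(0)) << " " << static_cast<float>(z(0)) << std::endl;
  return 0;
}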