From 03f646b7e352a68f26fff252615bedb4a2b359aa Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 21 Apr 2023 17:06:59 +0000 Subject: [PATCH] New VSX version of BF16 GEMV (Power) - up to 6.7X faster --- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 59 +-- .../Core/arch/AltiVec/MatrixProductCommon.h | 6 + .../arch/AltiVec/MatrixProductMMAbfloat16.h | 67 +-- .../Core/arch/AltiVec/MatrixVectorProduct.h | 498 +++++++++++++++++- 4 files changed, 525 insertions(+), 105 deletions(-) diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index f4e7e8dbd..5545234f9 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -2920,13 +2920,13 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, } } -template -EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][4]) +template +EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][size]) { Packet4f z = pset1(float(0)); for(Index k = 0; k < num_acc; k++) { - for(Index j = 0; j < 4; j++) { + for(Index j = 0; j < size; j++) { acc[k][j] = z; } } @@ -3246,59 +3246,6 @@ void gemmbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* #include "MatrixVectorProduct.h" -template -EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra) -{ - if (non_unit_stride) { - if (size < 8) { - pscatter_partial(dst + delta*resInc, data, resInc, extra); - } else { - pscatter(dst + delta*resInc, data, resInc); - } - } else { - if (size < 8) { - pstoreu_partial(dst + delta, data, extra); - } else { - pstoreu(dst + delta, data); - } - } -} - -template -EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1) -{ - constexpr Index extra = ((size < 8) ? 
8 : size); - for(; i + size <= rows; i += extra, dst += extra*resInc){ - PacketBlock r32; - r32.packet[0] = convertF32toBF16VSX(result + i + 0); - if (size >= 16) { - r32.packet[1] = convertF32toBF16VSX(result + i + 8); - } - if (size >= 32) { - r32.packet[2] = convertF32toBF16VSX(result + i + 16); - r32.packet[3] = convertF32toBF16VSX(result + i + 24); - } - storeBF16fromResult(dst, r32.packet[0], resInc, rows & 7); - if (size >= 16) { - storeBF16fromResult(dst, r32.packet[1], resInc); - } - if (size >= 32) { - storeBF16fromResult(dst, r32.packet[2], resInc); - storeBF16fromResult(dst, r32.packet[3], resInc); - } - } -} - -template -EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float *result, Index rows, bfloat16* dst, Index resInc = 1) -{ - Index i = 0; - convertPointerF32toBF16VSX<32,non_unit_stride>(i, result, rows, dst, resInc); - convertPointerF32toBF16VSX<16,non_unit_stride>(i, result, rows, dst, resInc); - convertPointerF32toBF16VSX<8,non_unit_stride>(i, result, rows, dst, resInc); - convertPointerF32toBF16VSX<1,non_unit_stride>(i, result, rows, dst, resInc); -} - /************************************ * ppc64le template specializations * * **********************************/ diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h index e89b5e557..fe135dc13 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h @@ -96,6 +96,12 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, template EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows); +template +EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha, Index extra_rows); + +template +EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha); + template EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs); diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h index 731fd9bcb..6bdbb0b56 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h @@ -44,6 +44,7 @@ EIGEN_ALWAYS_INLINE void KLoop { Packet8bf lhs[num_lhs], rhs[num_rhs]; + BFLOAT16_UNROLL for(Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++){ rhs[i] = loadRhsBfloat16(indexB + k*4, strideB, i); } @@ -52,8 +53,21 @@ EIGEN_ALWAYS_INLINE void KLoop } indexA += k*(lhsExtraRows ? extra_rows : num_packets); - for(Index j = 0; j < num_lhs; j++) { - lhs[j] = loadBfloat16(indexA + j*(zero ? 4 : 8)); // a packet of bfloat16 has 8 elements + if (num_lhs == 1) { + lhs[0] = loadBfloat16(indexA); + } else { + BFLOAT16_UNROLL + for(Index j = 0; j < num_lhs; j += 2) { + Packet8bf lhs1 = ploadu(indexA + (j + 0)*(zero ? 
4 : 8)); + if (zero) { + Packet8bf lhs2 = pset1(Eigen::bfloat16(0)); + lhs[j + 0] = vec_mergeh(lhs1.m_val, lhs2.m_val); + lhs[j + 1] = vec_mergel(lhs1.m_val, lhs2.m_val); + } else { + lhs[j + 0] = lhs1; + lhs[j + 1] = ploadu(indexA + (j + 1)*8); + } + } } BFLOAT16_UNROLL @@ -84,7 +98,9 @@ EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_a template EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows) { + BFLOAT16_UNROLL for(Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4*rows){ + BFLOAT16_UNROLL for(Index j = 0; j < num_lhs; j++, k++) { storeResults(acc[k], rows, pAlpha, result + j*4, extra_cols, extra_rows); } @@ -339,29 +355,6 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat #undef MAX_BFLOAT16_ACC #if !EIGEN_ALTIVEC_DISABLE_MMA -template -EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows) -{ - Packet4f d0 = ploadu(result); - d0 = pmadd(acc, pAlpha, d0); - if (extraRows) { - pstoreu_partial(result, d0, extra_rows); - } else { - pstoreu(result, d0); - } -} - -template -EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha, Index extra_rows) -{ - for(Index k = 0; k < num_acc - (extraRows ? 1 : 0); k++) { - outputVecCol(acc[k][0], result + k*4, pAlpha, extra_rows); - } - if (extraRows) { - outputVecCol(acc[num_acc - 1][0], result + (num_acc - 1)*4, pAlpha, extra_rows); - } -} - template EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper& lhs, Packet8bf (&a0)[num_acc], Packet8bf b1) { @@ -396,6 +389,7 @@ EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __v } LhsMapper lhs2 = lhs.getSubMapper(0, j); + BFLOAT16_UNROLL for(Index k = 0; k < num_acc; k += 2) { loadVecLoop(k, lhs2, a0, b1); } @@ -557,6 +551,7 @@ EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f (&acc)[num_acc][4], Index k) template EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4]) { + BFLOAT16_UNROLL for(Index k = 0; k < num_acc; k += 4) { preduxVecResults2(acc, k + 0); if (num_acc > (k + 2)) { @@ -566,27 +561,6 @@ EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4]) } } -template -EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha) -{ - constexpr Index extra = num_acc & 3; - - for(Index k = 0; k < num_acc; k += 4) { - Packet4f d0 = ploadu(result + k); - d0 = pmadd(acc[k + 0][0], pAlpha, d0); - - if (num_acc > (k + 3)) { - pstoreu(result + k, d0); - } else { - if (extra == 3) { - pstoreu_partial(result + k, d0, extra); - } else { - memcpy((void *)(result + k), (void *)(&d0), sizeof(float) * extra); - } - } - } -} - template EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const LhsMapper& lhs, RhsMapper& rhs, Index j, Index extra_cols) { @@ -599,6 +573,7 @@ EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const L } const LhsMapper lhs2 = lhs.getSubMapper(0, j); + BFLOAT16_UNROLL for(Index k = 0; k < num_acc; k++) { if (extra) { a0[k] = lhs2.template loadPacketPartial(k, 0, extra_cols); diff --git a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h index fa29b34e3..bb851fe1b 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h @@ -464,6 +464,492 @@ 
EIGEN_STRONG_INLINE void gemv_col( } } +template +EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows) +{ + Packet4f d0 = ploadu(result); + d0 = pmadd(acc, pAlpha, d0); + if (extraRows) { + pstoreu_partial(result, d0, extra_rows); + } else { + pstoreu(result, d0); + } +} + +template +EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha, Index extra_rows) +{ + constexpr Index real_acc = (num_acc - (extraRows ? 1 : 0)); + for(Index k = 0; k < real_acc; k++) { + outputVecCol(acc[k][0], result + k*4, pAlpha, extra_rows); + } + if (extraRows) { + outputVecCol(acc[real_acc][0], result + real_acc*4, pAlpha, extra_rows); + } +} + +static Packet16uc p16uc_MERGE16_32_V1 = { 0, 1, 16,17, 0, 1, 16,17, 0, 1, 16,17, 0, 1, 16,17 }; +static Packet16uc p16uc_MERGE16_32_V2 = { 2, 3, 18,19, 2, 3, 18,19, 2, 3, 18,19, 2, 3, 18,19 }; + +template +EIGEN_ALWAYS_INLINE void loadVecLoopVSX(Index k, LhsMapper& lhs, Packet4f (&a0)[num_acc][2]) +{ + Packet8bf c0 = lhs.template loadPacket(k*4, 0); + Packet8bf b1; + if (!zero) { + b1 = lhs.template loadPacket(k*4, 1); + + a0[k + 0][1] = oneConvertBF16Hi(b1.m_val); + } + a0[k + 0][0] = oneConvertBF16Hi(c0.m_val); + + if (num_acc > (k + 1)) { + a0[k + 1][0] = oneConvertBF16Lo(c0.m_val); + if (!zero) { + a0[k + 1][1] = oneConvertBF16Lo(b1.m_val); + } + } +} + +template +EIGEN_ALWAYS_INLINE void multVecVSX(Packet4f (&acc)[num_acc][2], Packet4f (&a0)[num_acc][2], Packet4f (&b0)[2]) +{ + for(Index k = 0; k < num_acc; k++) { + for(Index i = 0; i < (zero ? 1 : 2); i++) { + acc[k][i] = pmadd(b0[i], a0[k][i], acc[k][i]); + } + } +} + +template +EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2]) +{ + Packet4f a0[num_acc][2], b0[2]; + Packet8bf b2 = rhs.template loadPacket(j + 0); + + b0[0] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V1); + if (!zero) { + b0[1] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V2); + } + + LhsMapper lhs2 = lhs.getSubMapper(0, j); + for(Index k = 0; k < num_acc; k += 2) { + loadVecLoopVSX(k, lhs2, a0); + } + + multVecVSX(acc, a0, b0); +} + +template +EIGEN_ALWAYS_INLINE void addResultsVSX(Packet4f (&acc)[num_acc][2]) +{ + for(Index i = 0; i < num_acc; i++) { + acc[i][0] = acc[i][0] + acc[i][1]; + } +} + +// Uses 2X the accumulators or 4X the number of VSX registers +#define MAX_BFLOAT16_VEC_ACC_VSX 8 + +template +void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) +{ + constexpr Index step = (num_acc * 4); + const Index extra_rows = (extraRows) ? 
(rows & 3) : 0; + constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC_VSX); + + do{ + Packet4f acc[num_acc][2]; + + zeroAccumulators(acc); + + LhsMapper lhs2 = lhs.getSubMapper(row, 0); + for(Index j = 0; j + 2 <= cend; j += 2) { + vecColLoopVSX(j, lhs2, rhs, acc); + } + if (cend & 1) { + vecColLoopVSX(cend - 1, lhs2, rhs, acc); + } + + addResultsVSX(acc); + + outputVecColResults(acc, result, pAlpha, extra_rows); + + result += step; + } while(multiIters && (step <= rows - (row += step))); +} + +template +EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) +{ + if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) { + colVSXVecColLoopBody(row, cend, rows, lhs, rhs, pAlpha, result); + } +} + +template +EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) +{ + switch ((rows - row) >> 2) { + case 7: + colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result); + break; + case 6: + colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result); + break; + case 5: + colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result); + break; + case 4: + colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result); + break; + case 3: + colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result); + break; + case 2: + colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result); + break; + case 1: + colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result); + break; + default: + if (extraRows) { + colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result); + } + break; + } +} + +template +EIGEN_ALWAYS_INLINE void calcVSXVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) +{ + Index row = 0; + if (rows >= (MAX_BFLOAT16_VEC_ACC_VSX * 4)) { + colVSXVecColLoopBody(row, cend, rows, lhs, rhs, pAlpha, result); + result += row; + } + if (rows & 3) { + colVSXVecColLoopBodyExtra(row, cend, rows, lhs, rhs, pAlpha, result); + } else { + colVSXVecColLoopBodyExtra(row, cend, rows, lhs, rhs, pAlpha, result); + } +} + +template +EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra) +{ + if (inc) { + if (size < 8) { + pscatter_partial(dst + delta*resInc, data, resInc, extra); + } else { + pscatter(dst + delta*resInc, data, resInc); + } + } else { + if (size < 8) { + pstoreu_partial(dst + delta, data, extra); + } else { + pstoreu(dst + delta, data); + } + } +} + +template +EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1) +{ + constexpr Index extra = ((size < 8) ? 
8 : size); + for(; i + size <= rows; i += extra, dst += extra*resInc){ + PacketBlock r32; + r32.packet[0] = convertF32toBF16VSX(result + i + 0); + if (size >= 16) { + r32.packet[1] = convertF32toBF16VSX(result + i + 8); + } + if (size >= 32) { + r32.packet[2] = convertF32toBF16VSX(result + i + 16); + r32.packet[3] = convertF32toBF16VSX(result + i + 24); + } + storeBF16fromResult(dst, r32.packet[0], resInc, rows & 7); + if (size >= 16) { + storeBF16fromResult(dst, r32.packet[1], resInc); + } + if (size >= 32) { + storeBF16fromResult(dst, r32.packet[2], resInc); + storeBF16fromResult(dst, r32.packet[3], resInc); + } + } +} + +template +EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float *result, Index rows, bfloat16* dst, Index resInc = 1) +{ + Index i = 0; + convertPointerF32toBF16VSX<32,inc>(i, result, rows, dst, resInc); + convertPointerF32toBF16VSX<16,inc>(i, result, rows, dst, resInc); + convertPointerF32toBF16VSX<8,inc>(i, result, rows, dst, resInc); + convertPointerF32toBF16VSX<1,inc>(i, result, rows, dst, resInc); +} + +template +void gemv_bfloat16_col( + Index rows, Index cols, + const LhsMapper& alhs, + const RhsMapper& rhs, + bfloat16* res, Index resIncr, + bfloat16 alpha) +{ + typedef typename RhsMapper::LinearMapper LinearMapper; + + EIGEN_UNUSED_VARIABLE(resIncr); + eigen_internal_assert(resIncr == 1); + + // The following copy tells the compiler that lhs's attributes are not modified outside this function + // This helps GCC to generate proper code. + LhsMapper lhs(alhs); + RhsMapper rhs2(rhs); + + const Index lhsStride = lhs.stride(); + + // TODO: improve the following heuristic: + const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8); + float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha); + Packet4f pAlpha = pset1(falpha); + + ei_declare_aligned_stack_constructed_variable(float, result, rows, 0); + + convertArrayPointerBF16toF32(result, 1, rows, res); + + for (Index j2 = 0; j2 < cols; j2 += block_cols) + { + Index jend = numext::mini(j2 + block_cols, cols); + + LhsMapper lhs2 = lhs.getSubMapper(0, j2); + LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0); + calcVSXVecColLoops(jend - j2, rows, lhs2, rhs3, pAlpha, result); + } + + convertArrayPointerF32toBF16VSX(result, rows, res); +} + +template +EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha) +{ + constexpr Index extra = num_acc & 3; + + for(Index k = 0; k < num_acc; k += 4) { + Packet4f d0 = ploadu(result + k); + d0 = pmadd(acc[k + 0][0], pAlpha, d0); + + if (num_acc > (k + 3)) { + pstoreu(result + k, d0); + } else { + if (extra == 3) { + pstoreu_partial(result + k, d0, extra); + } else { + memcpy((void *)(result + k), (void *)(&d0), sizeof(float) * extra); + } + } + } +} + +template +EIGEN_ALWAYS_INLINE void preduxVecResults2VSX(Packet4f (&acc)[num_acc][2], Index k) +{ + if (num_acc > (k + 1)) { + acc[k][1] = vec_mergel(acc[k + 0][0], acc[k + 1][0]); + acc[k][0] = vec_mergeh(acc[k + 0][0], acc[k + 1][0]); + acc[k][0] = acc[k][0] + acc[k][1]; + acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8); + } else { + acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8); +#ifdef _BIG_ENDIAN + acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12); +#else + acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4); +#endif + } +} + +template +EIGEN_ALWAYS_INLINE void preduxVecResultsVSX(Packet4f (&acc)[num_acc][2]) +{ + for(Index k = 0; k < num_acc; k += 4) { + preduxVecResults2VSX(acc, k + 0); + if (num_acc > (k + 2)) { + preduxVecResults2VSX(acc, k + 
2); +#ifdef EIGEN_VECTORIZE_VSX + acc[k + 0][0] = reinterpret_cast(vec_mergeh(reinterpret_cast(acc[k + 0][0]), reinterpret_cast(acc[k + 2][0]))); +#else + acc[k + 0][0] = reinterpret_cast(vec_perm(acc[k + 0][0],acc[k + 2][0],p16uc_TRANSPOSE64_HI)); +#endif + } + } +} + +#ifndef _ARCH_PWR9 +EIGEN_ALWAYS_INLINE Packet8us loadPacketPartialZero(Packet8us data, Index extra_cols) +{ + Packet16uc shift = pset1(8 * 2 * (8 - extra_cols)); +#ifdef _BIG_ENDIAN + return reinterpret_cast(vec_slo(vec_sro(reinterpret_cast(data), shift), shift)); +#else + return reinterpret_cast(vec_sro(vec_slo(reinterpret_cast(data), shift), shift)); +#endif +} +#endif + +template +EIGEN_ALWAYS_INLINE void multVSXVecLoop(Packet4f (&acc)[num_acc][2], const LhsMapper& lhs, RhsMapper& rhs, Index j, Index extra_cols) +{ + Packet4f a0[num_acc][2], b0[2]; + Packet8bf a1, b1; + + if (extra) { + b1 = rhs.template loadPacketPartial(j, extra_cols); +#ifndef _ARCH_PWR9 + b1 = loadPacketPartialZero(b1.m_val, extra_cols); +#endif + } else { + b1 = rhs.template loadPacket(j); + } + b0[0] = oneConvertBF16Hi(b1.m_val); + b0[1] = oneConvertBF16Lo(b1.m_val); + + const LhsMapper lhs2 = lhs.getSubMapper(0, j); + for(Index k = 0; k < num_acc; k++) { + if (extra) { + a1 = lhs2.template loadPacketPartial(k, 0, extra_cols); +#ifndef _ARCH_PWR9 + a1 = loadPacketPartialZero(a1.m_val, extra_cols); +#endif + } else { + a1 = lhs2.template loadPacket(k, 0); + } + a0[k][0] = oneConvertBF16Hi(a1.m_val); + a0[k][1] = oneConvertBF16Lo(a1.m_val); + } + + multVecVSX(acc, a0, b0); +} + +template +EIGEN_ALWAYS_INLINE void vecVSXLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2], Index extra_cols) +{ + Index j = 0; + for(; j + 8 <= cols; j += 8){ + multVSXVecLoop(acc, lhs, rhs, j, extra_cols); + } + + if (extra_cols) { + multVSXVecLoop(acc, lhs, rhs, j, extra_cols); + } +} + +template +void colVSXVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) +{ + constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC_VSX); + const Index extra_cols = (cols & 7); + + do{ + Packet4f acc[num_acc][2]; + + zeroAccumulators(acc); + + const LhsMapper lhs2 = lhs.getSubMapper(row, 0); + vecVSXLoop(cols, lhs2, rhs, acc, extra_cols); + + addResultsVSX(acc); + + preduxVecResultsVSX(acc); + + outputVecResults(acc, result, pAlpha); + + result += num_acc; + } while(multiIters && (num_acc <= rows - (row += num_acc))); +} + +template +EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) +{ + if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) { + colVSXVecLoopBody(row, cols, rows, lhs, rhs, pAlpha, result); + } +} + +template +EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) +{ + switch (rows - row) { + case 7: + colVSXVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + break; + case 6: + colVSXVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + break; + case 5: + colVSXVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + break; + case 4: + colVSXVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + break; + case 3: + colVSXVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + break; + case 2: + 
colVSXVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + break; + case 1: + colVSXVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + break; + } +} + +template +EIGEN_ALWAYS_INLINE void calcVSXVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) +{ + Index row = 0; + if (rows >= MAX_BFLOAT16_VEC_ACC_VSX) { + colVSXVecLoopBody(row, cols, rows, lhs, rhs, pAlpha, result); + result += row; + } + colVSXVecLoopBodyExtra(row, cols, rows, lhs, rhs, pAlpha, result); +} + +template +EIGEN_STRONG_INLINE void gemv_bfloat16_row( + Index rows, Index cols, + const LhsMapper& alhs, + const RhsMapper& rhs, + bfloat16* res, Index resIncr, + bfloat16 alpha) +{ + typedef typename RhsMapper::LinearMapper LinearMapper; + + // The following copy tells the compiler that lhs's attributes are not modified outside this function + // This helps GCC to generate proper code. + LhsMapper lhs(alhs); + LinearMapper rhs2 = rhs.getLinearMapper(0, 0); + + eigen_internal_assert(rhs.stride() == 1); + + float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha); + const Packet4f pAlpha = pset1(falpha); + + ei_declare_aligned_stack_constructed_variable(float, result, rows, 0); + if (resIncr == 1) { + convertArrayPointerBF16toF32(result, 1, rows, res); + } else { + convertArrayPointerBF16toF32(result, 1, rows, res, resIncr); + } + calcVSXVecLoops(cols, rows, lhs, rhs2, pAlpha, result); + if (resIncr == 1) { + convertArrayPointerF32toBF16VSX(result, rows, res); + } else { + convertArrayPointerF32toBF16VSX(result, rows, res, resIncr); + } +} + +#undef MAX_BFLOAT16_VEC_ACC_VSX + const Packet16uc p16uc_COMPLEX32_XORFLIP = { 0x44,0x55,0x66,0x77, 0x00,0x11,0x22,0x33, 0xcc,0xdd,0xee,0xff, 0x88,0x99,0xaa,0xbb }; const Packet16uc p16uc_COMPLEX64_XORFLIP = { 0x88,0x99,0xaa,0xbb, 0xcc,0xdd,0xee,0xff, 0x00,0x11,0x22,0x33, 0x44,0x55,0x66,0x77 }; @@ -2062,6 +2548,13 @@ EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float) EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double) #ifdef USE_GEMV_MMA +#define gemv_bf16_col gemvMMA_bfloat16_col +#define gemv_bf16_row gemvMMA_bfloat16_row +#else +#define gemv_bf16_col gemv_bfloat16_col +#define gemv_bf16_row gemv_bfloat16_row +#endif + #define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16() \ template \ struct general_matrix_vector_product \ @@ -2072,7 +2565,7 @@ struct general_matrix_vector_product(rows, cols, lhs, rhs, res, resIncr, alpha); \ + gemv_bf16_col(rows, cols, lhs, rhs, res, resIncr, alpha); \ } \ }; @@ -2086,13 +2579,12 @@ struct general_matrix_vector_product(rows, cols, lhs, rhs, res, resIncr, alpha); \ + gemv_bf16_row(rows, cols, lhs, rhs, res, resIncr, alpha); \ } \ }; EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16() EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16() -#endif template EIGEN_ALWAYS_INLINE ScalarBlock predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1, ResPacket& b1)
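
For context, a minimal usage sketch (not part of the patch) of the bfloat16 GEMV paths this change adds. It assumes Eigen with this patch applied; on a Power/VSX build without MMA the products below should be routed through the new gemv_bfloat16_col / gemv_bfloat16_row kernels (via the gemv_bf16_col / gemv_bf16_row aliases selected near the end of MatrixVectorProduct.h), while MMA builds keep using the gemvMMA_bfloat16_* variants and non-Power targets fall back to the generic path. The driver below is hypothetical and only illustrates how the specializations get exercised.

// Hypothetical driver, not part of the patch: exercises the ColMajor and
// RowMajor bfloat16 GEMV specializations added above.
#include <Eigen/Dense>
#include <iostream>

int main() {
  using Mat = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, Eigen::Dynamic>;
  using Vec = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, 1>;

  // Sizes chosen as non-multiples of the packet/accumulator widths so the
  // "extra rows"/"extra cols" tail handling in the kernels is exercised.
  const Eigen::Index rows = 257, cols = 129;
  Mat A = Eigen::MatrixXf::Random(rows, cols).cast<Eigen::bfloat16>();
  Vec x = Eigen::VectorXf::Random(cols).cast<Eigen::bfloat16>();
  Vec y = Vec::Zero(rows);
  const Eigen::bfloat16 alpha(0.5f);

  // Column-major A * x dispatches to the ColMajor specialization (gemv_bf16_col).
  y.noalias() += alpha * A * x;

  // A.transpose() * y is evaluated as a RowMajor GEMV and dispatches to gemv_bf16_row.
  Vec z = A.transpose() * y;

  std::cout << z.cast<float>().norm() << std::endl;
  return 0;
}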