From 1148f0a9ec48bcedac69c9ed66d4b2f6bab89344 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 14 Apr 2023 22:20:42 +0000 Subject: [PATCH] Add dynamic dispatch to BF16 GEMM (Power) and new VSX version --- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 920 ++++++++++++++---- .../Core/arch/AltiVec/MatrixProductCommon.h | 12 + .../src/Core/arch/AltiVec/MatrixProductMMA.h | 2 - .../arch/AltiVec/MatrixProductMMAbfloat16.h | 314 ++---- .../Core/arch/AltiVec/MatrixVectorProduct.h | 4 +- Eigen/src/Core/arch/AltiVec/PacketMath.h | 442 ++++++--- 6 files changed, 1129 insertions(+), 565 deletions(-) diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 50270603b..f4e7e8dbd 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -841,7 +841,6 @@ struct dhs_pack } }; -#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__) // General template for lhs packing, bfloat16 specialization. template struct dhs_pack @@ -900,42 +899,60 @@ struct dhs_pack bload(block1, lhs2, 0 * vectorSize, i); bload(block2, lhs2, 1 * vectorSize, i); - Packet2ul v1[8], v2[8]; + Packet4ui v1[8], v2[8]; - v1[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); - v1[1] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); - v1[2] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val))); - v1[3] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val))); - v1[4] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val))); - v1[5] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val))); - v1[6] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val))); - v1[7] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val))); - v2[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(block2.packet[0].m_val), reinterpret_cast(block2.packet[1].m_val))); - v2[1] = reinterpret_cast(vec_mergel(reinterpret_cast(block2.packet[0].m_val), reinterpret_cast(block2.packet[1].m_val))); - v2[2] = reinterpret_cast(vec_mergeh(reinterpret_cast(block2.packet[2].m_val), reinterpret_cast(block2.packet[3].m_val))); - v2[3] = reinterpret_cast(vec_mergel(reinterpret_cast(block2.packet[2].m_val), reinterpret_cast(block2.packet[3].m_val))); - v2[4] = reinterpret_cast(vec_mergeh(reinterpret_cast(block2.packet[4].m_val), reinterpret_cast(block2.packet[5].m_val))); - v2[5] = reinterpret_cast(vec_mergel(reinterpret_cast(block2.packet[4].m_val), reinterpret_cast(block2.packet[5].m_val))); - v2[6] = reinterpret_cast(vec_mergeh(reinterpret_cast(block2.packet[6].m_val), reinterpret_cast(block2.packet[7].m_val))); - v2[7] = reinterpret_cast(vec_mergel(reinterpret_cast(block2.packet[6].m_val), reinterpret_cast(block2.packet[7].m_val))); - - block1.packet[0] = reinterpret_cast(vec_mergeh(v1[0],v1[2])); - block1.packet[2] = reinterpret_cast(vec_mergel(v1[0],v1[2])); - block1.packet[4] = reinterpret_cast(vec_mergeh(v1[1],v1[3])); - block1.packet[6] = reinterpret_cast(vec_mergel(v1[1],v1[3])); - block1.packet[1] = reinterpret_cast(vec_mergeh(v1[4],v1[6])); - block1.packet[3] = 
reinterpret_cast(vec_mergel(v1[4],v1[6])); - block1.packet[5] = reinterpret_cast(vec_mergeh(v1[5],v1[7])); - block1.packet[7] = reinterpret_cast(vec_mergel(v1[5],v1[7])); - block2.packet[0] = reinterpret_cast(vec_mergeh(v2[0],v2[2])); - block2.packet[2] = reinterpret_cast(vec_mergel(v2[0],v2[2])); - block2.packet[4] = reinterpret_cast(vec_mergeh(v2[1],v2[3])); - block2.packet[6] = reinterpret_cast(vec_mergel(v2[1],v2[3])); - block2.packet[1] = reinterpret_cast(vec_mergeh(v2[4],v2[6])); - block2.packet[3] = reinterpret_cast(vec_mergel(v2[4],v2[6])); - block2.packet[5] = reinterpret_cast(vec_mergeh(v2[5],v2[7])); - block2.packet[7] = reinterpret_cast(vec_mergel(v2[5],v2[7])); + v1[0] = vec_mergeh(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val)); + v1[1] = vec_mergel(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val)); + v1[2] = vec_mergeh(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val)); + v1[3] = vec_mergel(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val)); + v1[4] = vec_mergeh(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val)); + v1[5] = vec_mergel(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val)); + v1[6] = vec_mergeh(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val)); + v1[7] = vec_mergel(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val)); + v2[0] = vec_mergeh(reinterpret_cast(block2.packet[0].m_val), reinterpret_cast(block2.packet[1].m_val)); + v2[1] = vec_mergel(reinterpret_cast(block2.packet[0].m_val), reinterpret_cast(block2.packet[1].m_val)); + v2[2] = vec_mergeh(reinterpret_cast(block2.packet[2].m_val), reinterpret_cast(block2.packet[3].m_val)); + v2[3] = vec_mergel(reinterpret_cast(block2.packet[2].m_val), reinterpret_cast(block2.packet[3].m_val)); + v2[4] = vec_mergeh(reinterpret_cast(block2.packet[4].m_val), reinterpret_cast(block2.packet[5].m_val)); + v2[5] = vec_mergel(reinterpret_cast(block2.packet[4].m_val), reinterpret_cast(block2.packet[5].m_val)); + v2[6] = vec_mergeh(reinterpret_cast(block2.packet[6].m_val), reinterpret_cast(block2.packet[7].m_val)); + v2[7] = vec_mergel(reinterpret_cast(block2.packet[6].m_val), reinterpret_cast(block2.packet[7].m_val)); +#ifdef EIGEN_VECTORIZE_VSX + block1.packet[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(v1[0]),reinterpret_cast(v1[2]))); + block1.packet[2] = reinterpret_cast(vec_mergel(reinterpret_cast(v1[0]),reinterpret_cast(v1[2]))); + block1.packet[4] = reinterpret_cast(vec_mergeh(reinterpret_cast(v1[1]),reinterpret_cast(v1[3]))); + block1.packet[6] = reinterpret_cast(vec_mergel(reinterpret_cast(v1[1]),reinterpret_cast(v1[3]))); + block1.packet[1] = reinterpret_cast(vec_mergeh(reinterpret_cast(v1[4]),reinterpret_cast(v1[6]))); + block1.packet[3] = reinterpret_cast(vec_mergel(reinterpret_cast(v1[4]),reinterpret_cast(v1[6]))); + block1.packet[5] = reinterpret_cast(vec_mergeh(reinterpret_cast(v1[5]),reinterpret_cast(v1[7]))); + block1.packet[7] = reinterpret_cast(vec_mergel(reinterpret_cast(v1[5]),reinterpret_cast(v1[7]))); + block2.packet[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(v2[0]),reinterpret_cast(v2[2]))); + block2.packet[2] = reinterpret_cast(vec_mergel(reinterpret_cast(v2[0]),reinterpret_cast(v2[2]))); + block2.packet[4] = reinterpret_cast(vec_mergeh(reinterpret_cast(v2[1]),reinterpret_cast(v2[3]))); + block2.packet[6] = 
reinterpret_cast(vec_mergel(reinterpret_cast(v2[1]),reinterpret_cast(v2[3]))); + block2.packet[1] = reinterpret_cast(vec_mergeh(reinterpret_cast(v2[4]),reinterpret_cast(v2[6]))); + block2.packet[3] = reinterpret_cast(vec_mergel(reinterpret_cast(v2[4]),reinterpret_cast(v2[6]))); + block2.packet[5] = reinterpret_cast(vec_mergeh(reinterpret_cast(v2[5]),reinterpret_cast(v2[7]))); + block2.packet[7] = reinterpret_cast(vec_mergel(reinterpret_cast(v2[5]),reinterpret_cast(v2[7]))); +#else + block1.packet[0] = reinterpret_cast(vec_perm(v1[0],v1[2],p16uc_TRANSPOSE64_HI)); + block1.packet[2] = reinterpret_cast(vec_perm(v1[0],v1[2],p16uc_TRANSPOSE64_LO)); + block1.packet[4] = reinterpret_cast(vec_perm(v1[1],v1[3],p16uc_TRANSPOSE64_HI)); + block1.packet[6] = reinterpret_cast(vec_perm(v1[1],v1[3],p16uc_TRANSPOSE64_LO)); + block1.packet[1] = reinterpret_cast(vec_perm(v1[4],v1[6],p16uc_TRANSPOSE64_HI)); + block1.packet[3] = reinterpret_cast(vec_perm(v1[4],v1[6],p16uc_TRANSPOSE64_LO)); + block1.packet[5] = reinterpret_cast(vec_perm(v1[5],v1[7],p16uc_TRANSPOSE64_HI)); + block1.packet[7] = reinterpret_cast(vec_perm(v1[5],v1[7],p16uc_TRANSPOSE64_LO)); + block2.packet[0] = reinterpret_cast(vec_perm(v2[0],v2[2],p16uc_TRANSPOSE64_HI)); + block2.packet[2] = reinterpret_cast(vec_perm(v2[0],v2[2],p16uc_TRANSPOSE64_LO)); + block2.packet[4] = reinterpret_cast(vec_perm(v2[1],v2[3],p16uc_TRANSPOSE64_HI)); + block2.packet[6] = reinterpret_cast(vec_perm(v2[1],v2[3],p16uc_TRANSPOSE64_LO)); + block2.packet[1] = reinterpret_cast(vec_perm(v2[4],v2[6],p16uc_TRANSPOSE64_HI)); + block2.packet[3] = reinterpret_cast(vec_perm(v2[4],v2[6],p16uc_TRANSPOSE64_LO)); + block2.packet[5] = reinterpret_cast(vec_perm(v2[5],v2[7],p16uc_TRANSPOSE64_HI)); + block2.packet[7] = reinterpret_cast(vec_perm(v2[5],v2[7],p16uc_TRANSPOSE64_LO)); +#endif for(Index M = 0; M < 8; M+=2) { pstore(blockA + ri + (0 * vectorSize) + (2*vectorSize * M), block1.packet[M+0]); @@ -1005,26 +1022,37 @@ struct dhs_pack bload(block1, lhs2, 0 * vectorSize, i); - Packet2ul v1[8]; + Packet4ui v1[8]; // This is transposing and interleaving data - v1[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); - v1[1] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val))); - v1[2] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val))); - v1[3] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val))); - v1[4] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val))); - v1[5] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val))); - v1[6] = reinterpret_cast(vec_mergeh(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val))); - v1[7] = reinterpret_cast(vec_mergel(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val))); + v1[0] = vec_mergeh(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val)); + v1[1] = vec_mergel(reinterpret_cast(block1.packet[0].m_val), reinterpret_cast(block1.packet[1].m_val)); + v1[2] = vec_mergeh(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val)); + v1[3] = vec_mergel(reinterpret_cast(block1.packet[2].m_val), reinterpret_cast(block1.packet[3].m_val)); + v1[4] = 
vec_mergeh(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val)); + v1[5] = vec_mergel(reinterpret_cast(block1.packet[4].m_val), reinterpret_cast(block1.packet[5].m_val)); + v1[6] = vec_mergeh(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val)); + v1[7] = vec_mergel(reinterpret_cast(block1.packet[6].m_val), reinterpret_cast(block1.packet[7].m_val)); - block1.packet[0] = reinterpret_cast(vec_mergeh(v1[0],v1[2])); - block1.packet[2] = reinterpret_cast(vec_mergel(v1[0],v1[2])); - block1.packet[4] = reinterpret_cast(vec_mergeh(v1[1],v1[3])); - block1.packet[6] = reinterpret_cast(vec_mergel(v1[1],v1[3])); - block1.packet[1] = reinterpret_cast(vec_mergeh(v1[4],v1[6])); - block1.packet[3] = reinterpret_cast(vec_mergel(v1[4],v1[6])); - block1.packet[5] = reinterpret_cast(vec_mergeh(v1[5],v1[7])); - block1.packet[7] = reinterpret_cast(vec_mergel(v1[5],v1[7])); +#ifdef EIGEN_VECTORIZE_VSX + block1.packet[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(v1[0]),reinterpret_cast(v1[2]))); + block1.packet[2] = reinterpret_cast(vec_mergel(reinterpret_cast(v1[0]),reinterpret_cast(v1[2]))); + block1.packet[4] = reinterpret_cast(vec_mergeh(reinterpret_cast(v1[1]),reinterpret_cast(v1[3]))); + block1.packet[6] = reinterpret_cast(vec_mergel(reinterpret_cast(v1[1]),reinterpret_cast(v1[3]))); + block1.packet[1] = reinterpret_cast(vec_mergeh(reinterpret_cast(v1[4]),reinterpret_cast(v1[6]))); + block1.packet[3] = reinterpret_cast(vec_mergel(reinterpret_cast(v1[4]),reinterpret_cast(v1[6]))); + block1.packet[5] = reinterpret_cast(vec_mergeh(reinterpret_cast(v1[5]),reinterpret_cast(v1[7]))); + block1.packet[7] = reinterpret_cast(vec_mergel(reinterpret_cast(v1[5]),reinterpret_cast(v1[7]))); +#else + block1.packet[0] = reinterpret_cast(vec_perm(v1[0],v1[2],p16uc_TRANSPOSE64_HI)); + block1.packet[2] = reinterpret_cast(vec_perm(v1[0],v1[2],p16uc_TRANSPOSE64_LO)); + block1.packet[4] = reinterpret_cast(vec_perm(v1[1],v1[3],p16uc_TRANSPOSE64_HI)); + block1.packet[6] = reinterpret_cast(vec_perm(v1[1],v1[3],p16uc_TRANSPOSE64_LO)); + block1.packet[1] = reinterpret_cast(vec_perm(v1[4],v1[6],p16uc_TRANSPOSE64_HI)); + block1.packet[3] = reinterpret_cast(vec_perm(v1[4],v1[6],p16uc_TRANSPOSE64_LO)); + block1.packet[5] = reinterpret_cast(vec_perm(v1[5],v1[7],p16uc_TRANSPOSE64_HI)); + block1.packet[7] = reinterpret_cast(vec_perm(v1[5],v1[7],p16uc_TRANSPOSE64_LO)); +#endif for(Index M = 0; M < 8; M++) { pstore(blockA + ri + (vectorSize * M), block1.packet[M]); @@ -1157,16 +1185,24 @@ struct dhs_pack bload(block, rhs2, i, 0); - Packet2ul t0, t1, t2, t3; - t0 = reinterpret_cast(vec_mergeh(reinterpret_cast(block.packet[0].m_val), reinterpret_cast(block.packet[1].m_val))); - t1 = reinterpret_cast(vec_mergeh(reinterpret_cast(block.packet[2].m_val), reinterpret_cast(block.packet[3].m_val))); - t2 = reinterpret_cast(vec_mergel(reinterpret_cast(block.packet[0].m_val), reinterpret_cast(block.packet[1].m_val))); - t3 = reinterpret_cast(vec_mergel(reinterpret_cast(block.packet[2].m_val), reinterpret_cast(block.packet[3].m_val))); + Packet4ui t0, t1, t2, t3; - block.packet[0] = reinterpret_cast(vec_mergeh(t0, t1)); - block.packet[1] = reinterpret_cast(vec_mergel(t0, t1)); - block.packet[2] = reinterpret_cast(vec_mergeh(t2, t3)); - block.packet[3] = reinterpret_cast(vec_mergel(t2, t3)); + t0 = vec_mergeh(reinterpret_cast(block.packet[0].m_val), reinterpret_cast(block.packet[1].m_val)); + t1 = vec_mergel(reinterpret_cast(block.packet[0].m_val), reinterpret_cast(block.packet[1].m_val)); 
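+      // The four vec_mergeh/vec_mergel calls interleave the 32-bit lanes of the bfloat16 rows;
+      // the follow-up pass (64-bit merges under EIGEN_VECTORIZE_VSX, otherwise vec_perm with the
+      // p16uc_TRANSPOSE64_HI/LO masks) completes the 4x4 transpose of the rhs block before storeBlock writes it out.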
+ t2 = vec_mergeh(reinterpret_cast(block.packet[2].m_val), reinterpret_cast(block.packet[3].m_val)); + t3 = vec_mergel(reinterpret_cast(block.packet[2].m_val), reinterpret_cast(block.packet[3].m_val)); + +#ifdef EIGEN_VECTORIZE_VSX + block.packet[0] = reinterpret_cast(vec_mergeh(reinterpret_cast(t0),reinterpret_cast(t2))); + block.packet[1] = reinterpret_cast(vec_mergel(reinterpret_cast(t0),reinterpret_cast(t2))); + block.packet[2] = reinterpret_cast(vec_mergeh(reinterpret_cast(t1),reinterpret_cast(t3))); + block.packet[3] = reinterpret_cast(vec_mergel(reinterpret_cast(t1),reinterpret_cast(t3))); +#else + block.packet[0] = reinterpret_cast(vec_perm(t0,t2,p16uc_TRANSPOSE64_HI)); + block.packet[1] = reinterpret_cast(vec_perm(t0,t2,p16uc_TRANSPOSE64_LO)); + block.packet[2] = reinterpret_cast(vec_perm(t1,t3,p16uc_TRANSPOSE64_HI)); + block.packet[3] = reinterpret_cast(vec_perm(t1,t3,p16uc_TRANSPOSE64_LO)); +#endif storeBlock(blockB + ri, block); } else { @@ -1254,7 +1290,6 @@ struct dhs_pack } } }; -#endif // General template for lhs complex packing, float64 specialization. template @@ -2674,8 +2709,596 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl #undef advanceCols #undef advanceRows +EIGEN_ALWAYS_INLINE bool supportsMMA() +{ +#if defined(EIGEN_ALTIVEC_MMA_ONLY) + return true; +#else +#if EIGEN_COMP_LLVM + return false; // No dynamic dispatch for LLVM +#else + return __builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma"); +#endif +#endif +} + +EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result) +{ + Packet4f result_block = ploadu(result); + return pmadd(acc, pAlpha, result_block); +} + +template +EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows) +{ + if (lhsExtraRows) { + pstoreu_partial(result, result_block, extra_rows); + } else { + pstoreu(result, result_block); + } + result += rows; +} + +template +EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows) +{ + Index x = 0; + if (rhsExtraCols) { + do{ + Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result); + storeF32(result, result_block, rows, extra_rows); + } while (++x < extra_cols); + } else { + Packet4f result_block[4]; + float *result2 = result; + do{ + result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result); + result += rows; + } while (++x < 4); + x = 0; + do{ + storeF32(result2, result_block[x], rows, extra_rows); + } while (++x < 4); + } +} + +EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Hi(Packet8us data) +{ + Packet8us z = pset1(0); +#ifdef _BIG_ENDIAN + return reinterpret_cast(vec_mergeh(data, z)); +#else + return reinterpret_cast(vec_mergeh(z, data)); +#endif +} + +EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Lo(Packet8us data) +{ + Packet8us z = pset1(0); +#ifdef _BIG_ENDIAN + return reinterpret_cast(vec_mergel(data, z)); +#else + return reinterpret_cast(vec_mergel(z, data)); +#endif +} + +template +EIGEN_ALWAYS_INLINE void storeConvertTwoBF16(float* to, PacketBlock& block, Index extra = 0) +{ + if (N < 4) { + pstoreu_partial(to + 0, oneConvertBF16Hi(block.packet[0].m_val), extra); + } else if (N >= (M*8+4)) { + pstoreu(to + 0, oneConvertBF16Hi(block.packet[M].m_val)); + if (N >= 8) { + pstoreu(to + 4, oneConvertBF16Lo(block.packet[M].m_val)); + } + } +} + +template +EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock& block, Index extra) +{ + 
storeConvertTwoBF16(to + 0, block, extra); + if (N >= 16) { + storeConvertTwoBF16(to + 8, block); + } + if (N >= 32) { + storeConvertTwoBF16(to + 16, block); + storeConvertTwoBF16(to + 24, block); + } +} + +template +EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc) +{ + if (non_unit_stride) { + return pgather(src + delta*resInc, resInc); + } else { + return ploadu(src + delta); + } +} + +static Packet16uc p16uc_MERGE16_32_1 = { 0, 1, 16,17, 2, 3, 18,19, 0, 1, 16,17, 2, 3, 18,19 }; +static Packet16uc p16uc_MERGE16_32_2 = { 4, 5, 20,21, 6, 7, 22,23, 4, 5, 20,21, 6, 7, 22,23 }; +static Packet16uc p16uc_MERGE16_32_3 = { 8, 9, 24,25, 10,11, 26,27, 8, 9, 24,25, 10,11, 26,27 }; +static Packet16uc p16uc_MERGE16_32_4 = { 12,13, 28,29, 14,15, 30,31, 12,13, 28,29, 14,15, 30,31 }; + +static Packet16uc p16uc_MERGE16_32_5 = { 0,1, 16,17, 16,17, 16,17, 0,1, 16,17, 16,17, 16,17 }; +static Packet16uc p16uc_MERGE16_32_6 = { 2,3, 18,19, 18,19, 18,19, 2,3, 18,19, 18,19, 18,19 }; +static Packet16uc p16uc_MERGE16_32_7 = { 4,5, 20,21, 20,21, 20,21, 4,5, 20,21, 20,21, 20,21 }; +static Packet16uc p16uc_MERGE16_32_8 = { 6,7, 22,23, 22,23, 22,23, 6,7, 22,23, 22,23, 22,23 }; + +EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask) +{ + Packet8us z = pset1(0); +#ifdef _BIG_ENDIAN + return reinterpret_cast(vec_perm(data, z, mask)); +#else + return reinterpret_cast(vec_perm(z, data, mask)); +#endif +} + +template +EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index col, Index rows, const bfloat16* src, Index extra_rows) +{ + Packet4f dup[4*4]; + Packet8bf data[4]; + + for (Index i = 0; i < size; i++) { + data[i] = ploadu(src + col + rows*i); + } + + for (Index i = 0, j = 0; i < size; i++, j += 4) { + dup[j+0] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_5 : p16uc_MERGE16_32_1); + dup[j+1] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_6 : p16uc_MERGE16_32_2); + dup[j+2] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_7 : p16uc_MERGE16_32_3); + dup[j+3] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_8 : p16uc_MERGE16_32_4); + } + + for (Index j = 0; j < 4*size; j += 4) { + if (lhsExtraRows) { + Packet4f z = pset1(float(0)); + Index i = 0; + do { + pstoreu(result + (j+i)*4, dup[j+i]); + } while (++i < extra_rows); + do { + pstoreu(result + (j+i)*4, z); + } while (++i < 4); + } else { + for (Index i = 0; i < 4; i++) { + pstoreu(result + (j+i)*4, dup[j+i]); + } + } + } +} + +template +EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float *result, Index cols, Index rows, const bfloat16* src, Index delta, Index extra_rows) +{ + Index col2 = 0, col = 0; + for(; col + 4*2 <= cols; col += 4*2, col2 += 4*rows, result += 4*4*4) { + convertArrayPointerBF16toF32DupOne(result, col2 + delta*2, rows, src, extra_rows); + } + for(; col + 2 <= cols; col += 2, col2 += rows, result += 4*4) { + convertArrayPointerBF16toF32DupOne(result, col2 + delta*2, rows, src, extra_rows); + } + if (cols & 1) { + convertArrayPointerBF16toF32DupOne(result, col2 + delta, rows, src, extra_rows); + } +} + +template +EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc) +{ + constexpr Index extra = ((size < 4) ? 
4 : size); + for(; i + size <= rows; i += extra, src += extra*resInc){ + PacketBlock r32; + r32.packet[0] = loadBF16fromResult(src, resInc); + if (size >= 16) { + r32.packet[1] = loadBF16fromResult(src, resInc); + } + if (size >= 32) { + r32.packet[2] = loadBF16fromResult(src, resInc); + r32.packet[3] = loadBF16fromResult(src, resInc); + } + storeConvertBlockBF16(result + i, r32, rows & 3); + } +} + +template +EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, Index rows, bfloat16* src, Index resInc) +{ + for(Index col = 0; col < cols; col++, src += (rows*resInc), result += rows) { + Index i = 0; + bfloat16* src2 = src; + convertPointerBF16toF32<32, non_unit_stride>(i, result, rows, src2, resInc); + convertPointerBF16toF32<16, non_unit_stride>(i, result, rows, src2, resInc); + convertPointerBF16toF32<8, non_unit_stride>(i, result, rows, src2, resInc); + convertPointerBF16toF32<4, non_unit_stride>(i, result, rows, src2, resInc); + convertPointerBF16toF32<1, non_unit_stride>(i, result, rows, src2, resInc); + } +} + +template +EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][4]) +{ + Packet4f z = pset1(float(0)); + + for(Index k = 0; k < num_acc; k++) { + for(Index j = 0; j < 4; j++) { + acc[k][j] = z; + } + } +} + +template +EIGEN_ALWAYS_INLINE void tranposeResults(Packet4f (&acc)[num_acc][4]) +{ + for(Index i = 0; i < num_acc; i++) { + Packet4ui t0, t1, t2, t3; + t0 = vec_mergeh(reinterpret_cast(acc[i][0]), reinterpret_cast(acc[i][2])); + t1 = vec_mergel(reinterpret_cast(acc[i][0]), reinterpret_cast(acc[i][2])); + t2 = vec_mergeh(reinterpret_cast(acc[i][1]), reinterpret_cast(acc[i][3])); + t3 = vec_mergel(reinterpret_cast(acc[i][1]), reinterpret_cast(acc[i][3])); + acc[i][0] = reinterpret_cast(vec_mergeh(t0, t2)); + acc[i][1] = reinterpret_cast(vec_mergel(t0, t2)); + acc[i][2] = reinterpret_cast(vec_mergeh(t1, t3)); + acc[i][3] = reinterpret_cast(vec_mergel(t1, t3)); + } +} + +template +EIGEN_ALWAYS_INLINE void addResults(Packet4f (&acc)[num_acc][4]) +{ + for(Index i = 0, j = 0; j < num_acc; i++, j += 2) { + for(Index x = 0, y = 0; x < 2; x++, y += 2) { + for(Index w = 0, z = 0; w < 2; w++, z += 2) { + acc[i][y+w] = acc[j+x][z+0] + acc[j+x][z+1]; + } + } + } +} + +template +EIGEN_ALWAYS_INLINE void outputResultsVSX(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows) +{ + tranposeResults(acc); + addResults(acc); + + constexpr Index real_rhs = ((num_rhs / 2) - (rhsExtraCols ? 1 : 0)); + Index k = 0; + for(Index i = 0; i < real_rhs; i++, result += 4*rows, k++){ + storeResults(acc[k], rows, pAlpha, result, extra_cols, extra_rows); + } + if(rhsExtraCols) { + storeResults(acc[k], rows, pAlpha, result, extra_cols, extra_rows); + } +} + +template +EIGEN_ALWAYS_INLINE void loadTwoRhsFloat32(const float* block, Index strideB, Index i, Packet4f& dhs0, Packet4f &dhs1) +{ + dhs0 = ploadu(block + strideB*i + 0); + if (zero) { + Packet4f dhs2 = pset1(float(0)); + dhs1 = vec_mergel(dhs0, dhs2); + dhs0 = vec_mergeh(dhs0, dhs2); + } else { + dhs1 = ploadu(block + strideB*i + 4); + } +} + +template +EIGEN_ALWAYS_INLINE void KLoop +( + const float* indexA, + const float* indexB, + Packet4f (&acc)[num_acc][4], + Index strideB, + Index k, + Index offsetB, + Index extra_cols +) +{ + constexpr Index num_lhs = 4; + Packet4f lhs[num_lhs], rhs[num_rhs]; + + constexpr Index real_rhs = (num_rhs - (rhsExtraCols ? 
2 : 0)); + for(Index i = 0; i < real_rhs; i += 2){ + loadTwoRhsFloat32(indexB + k*4, strideB, i, rhs[i + 0], rhs[i + 1]); + } + if(rhsExtraCols) { + loadTwoRhsFloat32(indexB + k*extra_cols - offsetB, strideB, real_rhs, rhs[real_rhs + 0], rhs[real_rhs + 1]); + } + + indexA += 2*k*4; + for(Index j = 0; j < num_lhs; j++) { + lhs[j] = ploadu(indexA + j*4); + } + + for(Index j = 0; j < num_rhs; j++) { + for(Index i = 0; i < num_lhs; i++) { + acc[j][i] = pmadd(rhs[j], lhs[i], acc[j][i]); + } + } +} + +template +EIGEN_ALWAYS_INLINE void colVSXLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const float* indexA, const float* indexB, Index strideB, Index offsetB, float* result, const Index extra_cols, const Index extra_rows) +{ + constexpr Index num_rhs = num_acc; + + Packet4f acc[num_acc][4]; + + zeroAccumulators(acc); + + Index k; + for(k = 0; k + 2 <= depth; k += 2){ + KLoop(indexA, indexB, acc, strideB, k, offsetB, extra_cols); + } + if(depth&1){ + KLoop(indexA, indexB, acc, strideB, k, offsetB, extra_cols); + } + + outputResultsVSX(acc, rows, pAlpha, result, extra_cols, extra_rows); +} + +// No more than 4 (uses 2X the accumulators or 8X the number of VSX registers) +#define MAX_BFLOAT16_ACC_VSX 4 + +template +void colVSXLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA, const float* indexB, Index strideB, Index offsetB, float* result) +{ + constexpr Index step = (num_acc * 4); // each accumulator has 4 elements + const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0; + const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0; + constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC_VSX); + + do{ + colVSXLoopBodyIter(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows); + + indexB += strideB*(num_acc * 2); + result += rows*step; + } while(multiIters && (step <= cols - (col += step))); +} + +template +EIGEN_ALWAYS_INLINE void colVSXLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA, const float* blockB, Index strideB, Index offsetB, float* result) +{ + if (MAX_BFLOAT16_ACC_VSX > num_acc) { + colVSXLoopBody(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); + } +} + +template +void colVSXLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA, const float* blockB, Index strideB, Index offsetB, float* result) +{ + switch ((cols - col) >> 2) { + case 3: + colVSXLoopBodyExtraN<3, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); + break; + case 2: + colVSXLoopBodyExtraN<2, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); + break; + case 1: + colVSXLoopBodyExtraN<1, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); + break; + default: + if (rhsExtraCols) { + colVSXLoopBody<1, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); + } + break; + } +} + +template +EIGEN_ALWAYS_INLINE void colVSXLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const float* indexA2, const float* blockB2, Index strideA, Index strideB, Index offsetB, float* result2) +{ + Index delta_rows = 2*(lhsExtraRows ? 
(rows & 3) : size); + for (Index row = 0; row < size; row += 4) { + convertArrayPointerBF16toF32Dup(const_cast(indexA2), strideA, delta_rows, indexA, row, rows & 3); + + const float *blockB = blockB2; + float *result = result2 + row; + + Index col = 0; + if (cols >= (MAX_BFLOAT16_ACC_VSX * 4)) { + colVSXLoopBody(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result); + blockB += (strideB >> 1)*col; + result += rows*col; + } + if (cols & 3) { + colVSXLoopBodyExtra(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, offsetB, result); + } else { + colVSXLoopBodyExtra(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result); + } + } +} + +template +EIGEN_ALWAYS_INLINE void calcVSXColLoops(const bfloat16*& indexA, const float* indexA2, Index& row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexB, Index strideA, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result) +{ + if ((size == 16) || (rows & size)) { + indexA += size*offsetA; + colVSXLoops(depth, cols, rows, pAlpha, indexA, indexA2, indexB, strideA, strideB, offsetB, result + row); + row += size; + indexA += bigSuffix*size/16; + } +} + +template +EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src) +{ + constexpr Index extra = ((size < 4) ? 4 : size); + for(; i + size <= rows; i += extra){ + PacketBlock r32; + r32.packet[0] = src.template loadPacket(i + 0); + if (size >= 16) { + r32.packet[1] = src.template loadPacket(i + 8); + } + if (size >= 32) { + r32.packet[2] = src.template loadPacket(i + 16); + r32.packet[3] = src.template loadPacket(i + 24); + } + storeConvertBlockBF16(result + i, r32, rows & 3); + } +} + +template +EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src) +{ + typedef typename DataMapper::LinearMapper LinearMapper; + for(Index j = 0; j < cols; j++, result += rows){ + const LinearMapper src2 = src.getLinearMapper(0, j); + Index i = 0; + convertBF16toF32<32, LinearMapper>(i, result, rows, src2); + convertBF16toF32<16, LinearMapper>(i, result, rows, src2); + convertBF16toF32<8, LinearMapper>(i, result, rows, src2); + convertBF16toF32<4, LinearMapper>(i, result, rows, src2); + convertBF16toF32<1, LinearMapper>(i, result, rows, src2); + } +} + +EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16VSX(const float *res) +{ + return F32ToBf16Both(ploadu(res + 0), ploadu(res + 4)); +} + +template +EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float *result, Index col, Index rows, const DataMapper& res) +{ + const DataMapper res2 = res.getSubMapper(0, col); + Index row; + float *result2 = result + col*rows; + for(row = 0; row + 8 <= rows; row += 8){ + // get and save block + PacketBlock block; + for(Index j = 0; j < size; j++){ + block.packet[j] = convertF32toBF16VSX(result2 + j*rows + row); + } + res2.template storePacketBlock(row, 0, block); + } + // extra rows + if(row < rows){ + for(Index j = 0; j < size; j++){ + Packet8bf fp16 = convertF32toBF16VSX(result2 + j*rows + row); + res2.template storePacketPartial(row, j, fp16, rows & 7); + } + } +} + +template +EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float *result, Index cols, Index rows, const DataMapper& res) +{ + Index col; + for(col = 0; col + 4 <= cols; col += 4){ + convertArrayF32toBF16ColVSX(result, col, rows, res); + } + // extra cols + while(col < cols){ + convertArrayF32toBF16ColVSX(result, col, rows, res); + col++; + } +} + +template +void gemmbfloat16(const DataMapper& 
res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha); + const Packet4f pAlpha = pset1(falpha); + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0); + ei_declare_aligned_stack_constructed_variable(float, indexB2, strideB*cols, 0); + ei_declare_aligned_stack_constructed_variable(float, indexA2, ((strideA + 1) & -2)*4*2, 0); + + convertArrayBF16toF32(result, cols, rows, res); + convertArrayPointerBF16toF32(indexB2, cols, strideB, const_cast(indexB)); + + Index bigSuffix = 2*8*(strideA-offsetA); + float* indexBF32 = indexB2 + 4*offsetB; + offsetB *= 3; + strideB *= 2; + + Index row = 0; + // LHS (8x16) block + while(row + 16 <= rows){ + calcVSXColLoops<16>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB, bigSuffix, result); + } + // LHS (8x8) block + calcVSXColLoops<8>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB, bigSuffix, result); + // LHS (8x4) block + calcVSXColLoops<4>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB, bigSuffix, result); + // extra rows + if(rows & 3){ + // This index is the beginning of remaining block. + colVSXLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexA2, indexBF32, strideA, strideB, offsetB, result + row); + } + + // Convert back to bfloat16 + convertArrayF32toBF16VSX(result, cols, rows, res); +} + +#undef MAX_BFLOAT16_ACC_VSX + #include "MatrixVectorProduct.h" +template +EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra) +{ + if (non_unit_stride) { + if (size < 8) { + pscatter_partial(dst + delta*resInc, data, resInc, extra); + } else { + pscatter(dst + delta*resInc, data, resInc); + } + } else { + if (size < 8) { + pstoreu_partial(dst + delta, data, extra); + } else { + pstoreu(dst + delta, data); + } + } +} + +template +EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1) +{ + constexpr Index extra = ((size < 8) ? 
8 : size); + for(; i + size <= rows; i += extra, dst += extra*resInc){ + PacketBlock r32; + r32.packet[0] = convertF32toBF16VSX(result + i + 0); + if (size >= 16) { + r32.packet[1] = convertF32toBF16VSX(result + i + 8); + } + if (size >= 32) { + r32.packet[2] = convertF32toBF16VSX(result + i + 16); + r32.packet[3] = convertF32toBF16VSX(result + i + 24); + } + storeBF16fromResult(dst, r32.packet[0], resInc, rows & 7); + if (size >= 16) { + storeBF16fromResult(dst, r32.packet[1], resInc); + } + if (size >= 32) { + storeBF16fromResult(dst, r32.packet[2], resInc); + storeBF16fromResult(dst, r32.packet[3], resInc); + } + } +} + +template +EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float *result, Index rows, bfloat16* dst, Index resInc = 1) +{ + Index i = 0; + convertPointerF32toBF16VSX<32,non_unit_stride>(i, result, rows, dst, resInc); + convertPointerF32toBF16VSX<16,non_unit_stride>(i, result, rows, dst, resInc); + convertPointerF32toBF16VSX<8,non_unit_stride>(i, result, rows, dst, resInc); + convertPointerF32toBF16VSX<1,non_unit_stride>(i, result, rows, dst, resInc); +} + /************************************ * ppc64le template specializations * * **********************************/ @@ -2735,10 +3358,7 @@ void gemm_pack_rhs pack; pack(blockB, rhs, depth, cols, stride, offset); } -#endif -#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__) -#if EIGEN_ALTIVEC_USE_CUSTOM_PACK template struct gemm_pack_rhs { @@ -2795,7 +3415,6 @@ void gemm_pack_lhs pack; pack(blockA, lhs, depth, rows, stride, offset); } -#endif template struct gemm_pack_lhs @@ -2987,21 +3606,12 @@ void gebp_kernel::rows; const Index accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index); - - #if defined(EIGEN_ALTIVEC_MMA_ONLY) - //generate with MMA only - gemm_function = &Eigen::internal::gemmMMA; - #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) - if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemmMMA; - } - else{ - gemm_function = &Eigen::internal::gemm; - } - #else - gemm_function = &Eigen::internal::gemm; + static void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index) = + #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + (supportsMMA()) ? 
+ &Eigen::internal::gemmMMA : #endif + &Eigen::internal::gemm; gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } @@ -3025,22 +3635,13 @@ void gebp_kernel, std::complex, Index, DataMapper, mr { const Index accRows = quad_traits::rows; const Index accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, - Index, Index, Index, std::complex, Index, Index, Index, Index); - - #if defined(EIGEN_ALTIVEC_MMA_ONLY) - //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; - #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) - if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; - } - else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; - } - #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + static void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, + Index, Index, Index, std::complex, Index, Index, Index, Index) = + #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + (supportsMMA()) ? + &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false> : #endif + &Eigen::internal::gemm_complex, std::complex, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } @@ -3064,21 +3665,13 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjugat { const Index accRows = quad_traits::rows; const Index accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const float*, const std::complex*, - Index, Index, Index, std::complex, Index, Index, Index, Index); - #if defined(EIGEN_ALTIVEC_MMA_ONLY) - //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; - #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) - if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; - } - else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; - } - #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; + static void (*gemm_function)(const DataMapper&, const float*, const std::complex*, + Index, Index, Index, std::complex, Index, Index, Index, Index) = + #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + (supportsMMA()) ? 
+ &Eigen::internal::gemm_complexMMA, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false> : #endif + &Eigen::internal::gemm_complex, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } @@ -3102,21 +3695,13 @@ void gebp_kernel, float, Index, DataMapper, mr, nr, Conjugat { const Index accRows = quad_traits::rows; const Index accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const std::complex*, const float*, - Index, Index, Index, std::complex, Index, Index, Index, Index); - #if defined(EIGEN_ALTIVEC_MMA_ONLY) - //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; - #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) - if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; - } - else{ - gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; - } - #else - gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + static void (*gemm_function)(const DataMapper&, const std::complex*, const float*, + Index, Index, Index, std::complex, Index, Index, Index, Index) = + #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + (supportsMMA()) ? + &Eigen::internal::gemm_complexMMA, float, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true> : #endif + &Eigen::internal::gemm_complex, float, std::complex, float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } @@ -3139,21 +3724,12 @@ void gebp_kernel::rows; const Index accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index); - - #if defined(EIGEN_ALTIVEC_MMA_ONLY) - //generate with MMA only - gemm_function = &Eigen::internal::gemmMMA; - #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) - if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemmMMA; - } - else{ - gemm_function = &Eigen::internal::gemm; - } - #else - gemm_function = &Eigen::internal::gemm; + static void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index) = + #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + (supportsMMA()) ? 
+ &Eigen::internal::gemmMMA : #endif + &Eigen::internal::gemm; gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } @@ -3177,21 +3753,13 @@ void gebp_kernel, std::complex, Index, DataMapper, { const Index accRows = quad_traits::rows; const Index accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, - Index, Index, Index, std::complex, Index, Index, Index, Index); - #if defined(EIGEN_ALTIVEC_MMA_ONLY) - //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; - #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) - if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; - } - else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; - } - #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + static void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, + Index, Index, Index, std::complex, Index, Index, Index, Index) = + #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + (supportsMMA()) ? + &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false> : #endif + &Eigen::internal::gemm_complex, std::complex, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } @@ -3215,21 +3783,13 @@ void gebp_kernel, double, Index, DataMapper, mr, nr, Conjug { const Index accRows = quad_traits::rows; const Index accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const std::complex*, const double*, - Index, Index, Index, std::complex, Index, Index, Index, Index); - #if defined(EIGEN_ALTIVEC_MMA_ONLY) - //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; - #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) - if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; - } - else{ - gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; - } - #else - gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + static void (*gemm_function)(const DataMapper&, const std::complex*, const double*, + Index, Index, Index, std::complex, Index, Index, Index, Index) = + #ifdef 
EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + (supportsMMA()) ? + &Eigen::internal::gemm_complexMMA, double, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true> : #endif + &Eigen::internal::gemm_complex, double, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } @@ -3253,25 +3813,16 @@ void gebp_kernel, Index, DataMapper, mr, nr, Conjug { const Index accRows = quad_traits::rows; const Index accCols = quad_traits::size; - void (*gemm_function)(const DataMapper&, const double*, const std::complex*, - Index, Index, Index, std::complex, Index, Index, Index, Index); - #if defined(EIGEN_ALTIVEC_MMA_ONLY) - //generate with MMA only - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; - #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) - if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ - gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; - } - else{ - gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; - } - #else - gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; + static void (*gemm_function)(const DataMapper&, const double*, const std::complex*, + Index, Index, Index, std::complex, Index, Index, Index, Index) = + #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + (supportsMMA()) ? + &Eigen::internal::gemm_complexMMA, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false> : #endif + &Eigen::internal::gemm_complex, std::complex, double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } -#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__) template struct gebp_kernel { @@ -3289,9 +3840,14 @@ void gebp_kernel(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + static void (*gemm_function)(const DataMapper&, const bfloat16*, const bfloat16*, Index, Index, Index, bfloat16, Index, Index, Index, Index) = + #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + (supportsMMA()) ? 
+ &Eigen::internal::gemmMMAbfloat16 : + #endif + &Eigen::internal::gemmbfloat16; + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); } -#endif } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h index 1ac66292e..e89b5e557 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h @@ -84,6 +84,18 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols( const Packet& pAlphaImag, const Packet& pMask); +template +EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src); + +template +EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra = 0); + +template +EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, Index rows, bfloat16* src, Index resInc = 1); + +template +EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows); + template EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs); diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h index 05d180c0c..e4013a747 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -28,9 +28,7 @@ #include "../../InternalHeaderCheck.h" -#if !EIGEN_ALTIVEC_DISABLE_MMA #include "MatrixProductMMAbfloat16.h" -#endif namespace Eigen { diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h index fe4906d10..731fd9bcb 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h @@ -53,7 +53,7 @@ EIGEN_ALWAYS_INLINE void KLoop indexA += k*(lhsExtraRows ? extra_rows : num_packets); for(Index j = 0; j < num_lhs; j++) { - lhs[j] = loadBfloat16(indexA + j*(zero ? 4 : 8)); //a packet of bfloat16 has 8 elements + lhs[j] = loadBfloat16(indexA + j*(zero ? 
4 : 8)); // a packet of bfloat16 has 8 elements } BFLOAT16_UNROLL @@ -65,46 +65,6 @@ EIGEN_ALWAYS_INLINE void KLoop } } -EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result) -{ - Packet4f result_block = ploadu(result); - return pmadd(acc, pAlpha, result_block); -} - -template -EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows) -{ - if (lhsExtraRows) { - pstoreu_partial(result, result_block, extra_rows); - } else { - pstoreu(result, result_block); - } - result += rows; -} - -template -EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows) -{ - Index x = 0; - if (rhsExtraCols) { - do{ - Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result); - storeF32(result, result_block, rows, extra_rows); - } while (++x < extra_cols); - } else { - Packet4f result_block[4]; - float *result2 = result; - do{ - result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result); - result += rows; - } while (++x < 4); - x = 0; - do{ - storeF32(result2, result_block[x], rows, extra_rows); - } while (++x < 4); - } -} - template EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc]) { @@ -165,17 +125,14 @@ EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f template void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result) { - constexpr Index step = (num_acc * 4); //each accumulator has 4 elements + constexpr Index step = (num_acc * 4); // each accumulator has 4 elements const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0; const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0; constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC); + constexpr bool normIters = multiIters && ((num_acc % (num_packets / 4)) == 0); do{ - if (multiIters && ((num_acc % (num_packets / 4)) == 0)) { - colLoopBodyIter(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows); - } else { - colLoopBodyIter(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows); - } + colLoopBodyIter(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows); indexB += strideB*num_acc; result += rows*step; @@ -239,104 +196,89 @@ EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Pac } } -template EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res) { - Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(res + 0))); - Packet16uc fp16_1 = (full) ? 
__builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(res + 4))) : fp16_0; - return vec_pack(reinterpret_cast(fp16_0), reinterpret_cast(fp16_1)); + Packet16uc fp16[2]; +#if EIGEN_COMP_LLVM + __vector_pair fp16_vp = *reinterpret_cast<__vector_pair *>(const_cast(res)); + __builtin_vsx_disassemble_pair(reinterpret_cast(fp16), &fp16_vp); + fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]); + fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]); +#else + fp16[0] = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(res + 0))); + fp16[1] = __builtin_vsx_xvcvspbf16(reinterpret_cast(ploadu(res + 4))); +#endif + return vec_pack(reinterpret_cast(fp16[0]), reinterpret_cast(fp16[1])); } -template -EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock& block) +template +EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float *result, Index col, Index rows, const DataMapper& res) { - Packet8us z = pset1(0); - pstore(to + 0, reinterpret_cast(vec_mergeh(z, block.packet[0].m_val))); - if (N >= 8) { - pstore(to + 4, reinterpret_cast(vec_mergel(z, block.packet[0].m_val))); + const DataMapper res2 = res.getSubMapper(0, col); + Index row; + float *result2 = result + col*rows; + for(row = 0; row + 8 <= rows; row += 8){ + // get and save block + PacketBlock block; + for(Index j = 0; j < size; j++){ + block.packet[j] = convertF32toBF16(result2 + j*rows + row); + } + res2.template storePacketBlock(row, 0, block); } - if (N >= 16) { - pstore(to + 8, reinterpret_cast(vec_mergeh(z, block.packet[1].m_val))); - pstore(to + 12, reinterpret_cast(vec_mergel(z, block.packet[1].m_val))); - } - if (N >= 32) { - pstore(to + 16, reinterpret_cast(vec_mergeh(z, block.packet[2].m_val))); - pstore(to + 20, reinterpret_cast(vec_mergel(z, block.packet[2].m_val))); - pstore(to + 24, reinterpret_cast(vec_mergeh(z, block.packet[3].m_val))); - pstore(to + 28, reinterpret_cast(vec_mergel(z, block.packet[3].m_val))); + // extra rows + if(row < rows){ + for(Index j = 0; j < size; j++){ + Packet8bf fp16 = convertF32toBF16(result2 + j*rows + row); + res2.template storePacketPartial(row, j, fp16, rows & 7); + } } } -template -EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src) +template +EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1) { - for(; i + size <= rows; i += size){ - PacketBlock r32; - r32.packet[0] = src.template loadPacket(i + 0); + constexpr Index extra = ((size < 8) ? 
8 : size); + for(; i + size <= rows; i += extra, dst += extra*resInc){ + PacketBlock r32; + r32.packet[0] = convertF32toBF16(result + i + 0); if (size >= 16) { - r32.packet[1] = src.template loadPacket(i + 8); + r32.packet[1] = convertF32toBF16(result + i + 8); } if (size >= 32) { - r32.packet[2] = src.template loadPacket(i + 16); - r32.packet[3] = src.template loadPacket(i + 24); + r32.packet[2] = convertF32toBF16(result + i + 16); + r32.packet[3] = convertF32toBF16(result + i + 24); + } + storeBF16fromResult(dst, r32.packet[0], resInc, rows & 7); + if (size >= 16) { + storeBF16fromResult(dst, r32.packet[1], resInc); + } + if (size >= 32) { + storeBF16fromResult(dst, r32.packet[2], resInc); + storeBF16fromResult(dst, r32.packet[3], resInc); } - storeConvertBlockBF16(result + i, r32); } } -template -EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src) +template +EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1) { - typedef typename DataMapper::LinearMapper LinearMapper; - for(Index j = 0; j < cols; j++, result += rows){ - const LinearMapper src2 = src.getLinearMapper(0, j); - Index i = 0; - convertBF16toF32<32, LinearMapper>(i, result, rows, src2); - convertBF16toF32<16, LinearMapper>(i, result, rows, src2); - convertBF16toF32<8, LinearMapper>(i, result, rows, src2); - convertBF16toF32<4, LinearMapper>(i, result, rows, src2); - for(; i < rows; i++){ - result[i] = Eigen::bfloat16_impl::bfloat16_to_float(src2(i)); - } - } + Index i = 0; + convertPointerF32toBF16<32,non_unit_stride>(i, result, rows, dst, resInc); + convertPointerF32toBF16<16,non_unit_stride>(i, result, rows, dst, resInc); + convertPointerF32toBF16<8,non_unit_stride>(i, result, rows, dst, resInc); + convertPointerF32toBF16<1,non_unit_stride>(i, result, rows, dst, resInc); } template EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper& res) { - typedef typename DataMapper::LinearMapper LinearMapper; - Index col, row; + Index col; for(col = 0; col + 4 <= cols; col += 4){ - const DataMapper res2 = res.getSubMapper(0, col); - for(row = 0; row + 8 <= rows; row += 8){ - //get and save block - PacketBlock block; - for(Index j = 0; j < 4; j++){ - block.packet[j].m_val = convertF32toBF16(result + (col + j)*rows + row); - } - - res2.template storePacketBlock(row, 0, block); - } - //extra rows - while(row < rows){ - for(Index col_off = 0; col_off < 4; col_off++){ - res2(row, col_off) = Eigen::bfloat16(result[(col+col_off)*rows+row]); - } - row++; - } - + convertArrayF32toBF16Col(result, col, rows, res); } - //extra cols + // extra cols while(col < cols){ - const LinearMapper res2 = res.getLinearMapper(0, col); - float *result2 = result + col*rows; - for(row = 0; row + 8 <= rows; row += 8){ - Packet8bf fp16 = convertF32toBF16(result2 + row); - res2.template storePacket(row, fp16); - } - for(; row < rows; row++){ - res2(row) = Eigen::bfloat16(result2[row]); - } + convertArrayF32toBF16Col(result, col, rows, res); col++; } } @@ -361,134 +303,42 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat convertArrayBF16toF32(result, cols, rows, res); - Index row = 0; - if( strideA == -1 ) strideA = depth; if( strideB == -1 ) strideB = depth; - //Packing is done in blocks. - //There's 4 possible sizes of blocks - //Blocks of 8 columns with 16 elements (8x16) - //Blocks of 8 columns with 8 elements (8x8). 
This happens when there's 16 > rows >= 8 - //Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4 - //Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows + // Packing is done in blocks. + // There's 4 possible sizes of blocks + // Blocks of 8 columns with 16 elements (8x16) + // Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8 + // Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4 + // Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows - //Loop for LHS standard block (8x16) + // Loop for LHS standard block (8x16) Index bigSuffix = (2*8) * (strideA-offsetA); indexB += 4*offsetB; strideB *= 4; offsetB *= 3; + + Index row = 0; while(row + 16 <= rows){ calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result); } - //LHS (8x8) block + // LHS (8x8) block calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result); - //LHS (8x4) block + // LHS (8x4) block calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result); - //extra rows + // extra rows if(rows & 3){ - //This index is the beginning of remaining block. + // This index is the beginning of remaining block. colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row); } - //Convert back to bfloat16 + // Convert back to bfloat16 convertArrayF32toBF16(result, cols, rows, res); } -template -EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc) -{ - if (inc) { - if (size == 4) { - pscatter_partial(dst + delta*resInc, data, resInc, 4); - } else { - pscatter(dst + delta*resInc, data, resInc); - } - } else { - if (size == 4) { - pstoreu_partial(dst + delta, data, 4); - } else { - pstoreu(dst + delta, data); - } - } -} - -template -EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc) -{ - for(; i + size <= rows; i += size, dst += size*resInc){ - PacketBlock r32; - r32.packet[0] = convertF32toBF16(result + i + 0); - if (size >= 16) { - r32.packet[1] = convertF32toBF16(result + i + 8); - } - if (size >= 32) { - r32.packet[2] = convertF32toBF16(result + i + 16); - r32.packet[3] = convertF32toBF16(result + i + 24); - } - storeBF16fromResult(dst, r32.packet[0], resInc); - if (size >= 16) { - storeBF16fromResult(dst, r32.packet[1], resInc); - } - if (size >= 32) { - storeBF16fromResult(dst, r32.packet[2], resInc); - storeBF16fromResult(dst, r32.packet[3], resInc); - } - } -} - -template -EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc) -{ - if (inc) { - return pgather(src + delta*resInc, resInc); - } else { - return ploadu(src + delta); - } -} - -template -EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc) -{ - for(; i + size <= rows; i += size, src += size*resInc){ - PacketBlock r32; - r32.packet[0] = loadBF16fromResult(src, resInc); - if (size >= 16) { - r32.packet[1] = loadBF16fromResult(src, resInc); - } - if (size >= 32) { - r32.packet[2] = loadBF16fromResult(src, resInc); - r32.packet[3] = loadBF16fromResult(src, resInc); - } - storeConvertBlockBF16(result + i, r32); - } -} - -template -EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index rows, bfloat16* src, Index resInc = 1) -{ - 
Index i = 0; - convertPointerBF16toF32<32, inc>(i, result, rows, src, resInc); - convertPointerBF16toF32<16, inc>(i, result, rows, src, resInc); - convertPointerBF16toF32<8, inc>(i, result, rows, src, resInc); - convertPointerBF16toF32<4, inc>(i, result, rows, src, resInc); - for(; i < rows; i++, src += resInc){ - result[i] = Eigen::bfloat16_impl::bfloat16_to_float(*src); - } -} - -template -EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1) -{ - Index i = 0; - convertPointerF32toBF16<32,inc>(i, result, rows, dst, resInc); - convertPointerF32toBF16<16,inc>(i, result, rows, dst, resInc); - convertPointerF32toBF16<8,inc>(i, result, rows, dst, resInc); - convertPointerF32toBF16<4,inc>(i, result, rows, dst, resInc); - for(; i < rows; i++, dst += resInc){ - *dst = Eigen::bfloat16(result[i]); - } -} +#undef MAX_BFLOAT16_ACC +#if !EIGEN_ALTIVEC_DISABLE_MMA template EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows) { @@ -667,7 +517,7 @@ void gemvMMA_bfloat16_col( ei_declare_aligned_stack_constructed_variable(float, result, rows, 0); - convertArrayPointerBF16toF32(result, rows, res); + convertArrayPointerBF16toF32(result, 1, rows, res); for (Index j2 = 0; j2 < cols; j2 += block_cols) { @@ -867,9 +717,9 @@ EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row( ei_declare_aligned_stack_constructed_variable(float, result, rows, 0); if (resIncr == 1) { - convertArrayPointerBF16toF32(result, rows, res); + convertArrayPointerBF16toF32(result, 1, rows, res); } else { - convertArrayPointerBF16toF32(result, rows, res, resIncr); + convertArrayPointerBF16toF32(result, 1, rows, res, resIncr); } calcVecLoops(cols, rows, lhs, rhs2, pAlpha, result); if (resIncr == 1) { @@ -878,6 +728,10 @@ EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row( convertArrayPointerF32toBF16(result, rows, res, resIncr); } } +#endif + +#undef MAX_BFLOAT16_VEC_ACC +#undef BFLOAT16_UNROLL } } diff --git a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h index e107335b7..fa29b34e3 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h @@ -17,8 +17,8 @@ #define USE_GEMV_MMA #endif -#if !EIGEN_COMP_LLVM && (__GNUC__ == 10 && __GNUC_MINOR__ <= 3) -// Only allow one vector_pair in buggy gcc - gcc 10.3 has a bug +#if !EIGEN_COMP_LLVM && (__GNUC__ < 11) +// Only allow one vector_pair in buggy gcc - gcc 10.x has a bug #define GCC_ONE_VECTORPAIR_BUG #endif #endif diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index d477ab721..e0168513d 100644 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -35,6 +35,7 @@ typedef __vector unsigned int Packet4ui; typedef __vector __bool int Packet4bi; typedef __vector short int Packet8s; typedef __vector unsigned short int Packet8us; +typedef __vector __bool short Packet8bi; typedef __vector signed char Packet16c; typedef __vector unsigned char Packet16uc; typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf; @@ -83,10 +84,7 @@ static EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS16,-16); //{ -16, -16, -16, -16} static EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1} static EIGEN_DECLARE_CONST_FAST_Packet4ui(SIGN, 0x80000000u); static EIGEN_DECLARE_CONST_FAST_Packet4ui(PREV0DOT5, 0x3EFFFFFFu); -#ifndef __POWER8_VECTOR__ static EIGEN_DECLARE_CONST_FAST_Packet8us(ONE,1); //{ 1, 
1, 1, 1, 1, 1, 1, 1} -static EIGEN_DECLARE_CONST_FAST_Packet16uc(ONE,1); -#endif static Packet4f p4f_MZERO = (Packet4f) vec_sl((Packet4ui)p4i_MINUS1, (Packet4ui)p4i_MINUS1); //{ 0x80000000, 0x80000000, 0x80000000, 0x80000000} #ifndef __VSX__ static Packet4f p4f_ONE = vec_ctf(p4i_ONE, 0); //{ 1.0, 1.0, 1.0, 1.0} @@ -116,6 +114,14 @@ static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10, static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 }; +static Packet16uc p16uc_MERGEE16 = { 0,1, 16,17, 4,5, 20,21, 8,9, 24,25, 12,13, 28,29 }; +static Packet16uc p16uc_MERGEO16 = { 2,3, 18,19, 6,7, 22,23, 10,11, 26,27, 14,15, 30,31 }; +#ifdef _BIG_ENDIAN +static Packet16uc p16uc_MERGEH16 = { 0,1, 4,5, 8,9, 12,13, 16,17, 20,21, 24,25, 28,29 }; +#else +static Packet16uc p16uc_MERGEL16 = { 2,3, 6,7, 10,11, 14,15, 18,19, 22,23, 26,27, 30,31 }; +#endif + // Handle endianness properly while loading constants // Define global static constants: #ifdef _BIG_ENDIAN @@ -537,31 +543,20 @@ EIGEN_ALWAYS_INLINE Packet pload_partial_common(const __UNPACK_TYPE__(Packet)* f } return load; #else - EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size]; - unsigned char* load2 = reinterpret_cast(load + offset); - unsigned char* from2 = reinterpret_cast(const_cast<__UNPACK_TYPE__(Packet)*>(from)); - Index n2 = n * size; - Index i = 0; - if (16 <= n2) { - pstoreu(load2, ploadu(from2)); - i += 16; + if (n) { + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size]; + unsigned char* load2 = reinterpret_cast(load + offset); + unsigned char* from2 = reinterpret_cast(const_cast<__UNPACK_TYPE__(Packet)*>(from)); + Index n2 = n * size; + if (16 <= n2) { + pstoreu(load2, ploadu(from2)); + } else { + memcpy((void *)load2, (void *)from2, n2); + } + return pload_ignore(load); + } else { + return Packet(pset1(0)); } - if (i + 8 <= n2) { - *reinterpret_cast(load2 + i) = *reinterpret_cast(from2 + i); - i += 8; - } - if (i + 4 <= n2) { - *reinterpret_cast(load2 + i) = *reinterpret_cast(from2 + i); - i += 4; - } - if (i + 2 <= n2) { - *reinterpret_cast(load2 + i) = *reinterpret_cast(from2 + i); - i += 2; - } - if (i < n2) { - *reinterpret_cast(load2 + i) = *reinterpret_cast(from2 + i); - } - return pload_ignore(load); #endif } @@ -635,7 +630,7 @@ template<> EIGEN_STRONG_INLINE void pstore(unsigned short in template<> EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet8bf& from) { - pstore_common(reinterpret_cast(to), from); + pstore_common(reinterpret_cast(to), from.m_val); } template<> EIGEN_STRONG_INLINE void pstore(signed char* to, const Packet16c& from) @@ -670,30 +665,17 @@ template EIGEN_ALWAYS_INLINE void pstore_partial_common(__UNPAC } vec_xst_len(store, to, n * size); #else - EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size]; - pstore(store, from); - unsigned char* store2 = reinterpret_cast(store + offset); - unsigned char* to2 = reinterpret_cast(to); - Index n2 = n * size; - Index i = 0; - if (16 <= n2) { - pstore(to2, ploadu(store2)); - i += 16; - } - if (i + 8 <= n2) { - *reinterpret_cast(to2 + i) = *reinterpret_cast(store2 + i); - i += 8; - } - if (i + 4 <= n2) { - *reinterpret_cast(to2 + i) = *reinterpret_cast(store2 + i); - i += 4; - } - if (i + 2 <= n2) { - *reinterpret_cast(to2 + i) = *reinterpret_cast(store2 + i); - i += 2; - } - if (i < n2) { - *reinterpret_cast(to2 + i) = *reinterpret_cast(store2 + i); + if (n) { + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size]; + pstore(store, from); + unsigned char* store2 = reinterpret_cast(store + offset); + 
unsigned char* to2 = reinterpret_cast(to); + Index n2 = n * size; + if (16 <= n2) { + pstore(to2, ploadu(store2)); + } else { + memcpy((void *)to2, (void *)store2, n2); + } } #endif } @@ -720,7 +702,7 @@ template<> EIGEN_ALWAYS_INLINE void pstore_partial(unsigned template<> EIGEN_ALWAYS_INLINE void pstore_partial(bfloat16* to, const Packet8bf& from, const Index n, const Index offset) { - pstore_partial_common(reinterpret_cast(to), from, n, offset); + pstore_partial_common(reinterpret_cast(to), from.m_val, n, offset); } template<> EIGEN_ALWAYS_INLINE void pstore_partial(signed char* to, const Packet16c& from, const Index n, const Index offset) @@ -1003,6 +985,22 @@ template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) return vec_xor(a, p4f_MZERO); #endif } +template<> EIGEN_STRONG_INLINE Packet16c pnegate(const Packet16c& a) +{ +#ifdef __POWER8_VECTOR__ + return vec_neg(a); +#else + return reinterpret_cast(p4i_ZERO) - a; +#endif +} +template<> EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) +{ +#ifdef __POWER8_VECTOR__ + return vec_neg(a); +#else + return reinterpret_cast(p4i_ZERO) - a; +#endif +} template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { #ifdef __POWER8_VECTOR__ @@ -1102,7 +1100,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc pmax(const Packet16uc& a, template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast(vec_cmple(a,b)); } // To fix bug with vec_cmplt on older versions -#if defined(__POWER8_VECTOR__) || EIGEN_COMP_LLVM +#ifdef EIGEN_VECTORIZE_VSX template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast(vec_cmplt(a,b)); } #endif template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast(vec_cmpeq(a,b)); } @@ -1256,31 +1254,20 @@ template EIGEN_ALWAYS_INLINE Packet ploadu_partial_common(const EIGEN_DEBUG_UNALIGNED_LOAD return vec_xl_len(const_cast<__UNPACK_TYPE__(Packet)*>(from), n * size); #else - EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size]; - unsigned char* load2 = reinterpret_cast(load); - unsigned char* from2 = reinterpret_cast(const_cast<__UNPACK_TYPE__(Packet)*>(from)); - Index n2 = n * size; - Index i = 0; - if (16 <= n2) { - pstore(load2, ploadu(from2)); - i += 16; + if (n) { + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size]; + unsigned char* load2 = reinterpret_cast(load); + unsigned char* from2 = reinterpret_cast(const_cast<__UNPACK_TYPE__(Packet)*>(from)); + Index n2 = n * size; + if (16 <= n2) { + pstore(load2, ploadu(from2)); + } else { + memcpy((void *)load2, (void *)from2, n2); + } + return pload_ignore(load); + } else { + return Packet(pset1(0)); } - if (i + 8 <= n2) { - *reinterpret_cast(load2 + i) = *reinterpret_cast(from2 + i); - i += 8; - } - if (i + 4 <= n2) { - *reinterpret_cast(load2 + i) = *reinterpret_cast(from2 + i); - i += 4; - } - if (i + 2 <= n2) { - *reinterpret_cast(load2 + i) = *reinterpret_cast(from2 + i); - i += 2; - } - if (i < n2) { - *reinterpret_cast(load2 + i) = *reinterpret_cast(from2 + i); - } - return pload_ignore(load); #endif } @@ -1422,7 +1409,7 @@ template<> EIGEN_STRONG_INLINE void pstoreu(unsigned short i } template<> EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet8bf& from) { - pstoreu_common(reinterpret_cast(to), from); + pstoreu_common(reinterpret_cast(to), from.m_val); } template<> EIGEN_STRONG_INLINE void pstoreu(signed char* to, const Packet16c& from) { @@ -1443,30 +1430,17 @@ template 
EIGEN_ALWAYS_INLINE void pstoreu_partial_common(__UNPA EIGEN_DEBUG_UNALIGNED_STORE vec_xst_len(from, to, n * size); #else - EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size]; - pstore(store, from); - unsigned char* store2 = reinterpret_cast(store); - unsigned char* to2 = reinterpret_cast(to); - Index n2 = n * size; - Index i = 0; - if (16 <= n2) { - pstoreu(to2, pload(store2)); - i += 16; - } - if (i + 8 <= n2) { - *reinterpret_cast(to2 + i) = *reinterpret_cast(store2 + i); - i += 8; - } - if (i + 4 <= n2) { - *reinterpret_cast(to2 + i) = *reinterpret_cast(store2 + i); - i += 4; - } - if (i + 2 <= n2) { - *reinterpret_cast(to2 + i) = *reinterpret_cast(store2 + i); - i += 2; - } - if (i < n2) { - *reinterpret_cast(to2 + i) = *reinterpret_cast(store2 + i); + if (n) { + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size]; + pstore(store, from); + unsigned char* store2 = reinterpret_cast(store); + unsigned char* to2 = reinterpret_cast(to); + Index n2 = n * size; + if (16 <= n2) { + pstoreu(to2, pload(store2)); + } else { + memcpy((void *)to2, (void *)store2, n2); + } } #endif } @@ -1636,17 +1610,37 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){ ); } +EIGEN_ALWAYS_INLINE Packet8us pmerge(Packet4ui even, Packet4ui odd) { +#ifdef _BIG_ENDIAN + return vec_perm(reinterpret_cast(odd), reinterpret_cast(even), p16uc_MERGEO16); +#else + return vec_perm(reinterpret_cast(even), reinterpret_cast(odd), p16uc_MERGEE16); +#endif +} + // Simple interleaving of bool masks, prevents true values from being // converted to NaNs. EIGEN_STRONG_INLINE Packet8bf F32ToBf16Bool(Packet4f even, Packet4f odd) { - const EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000); - Packet4f bf_odd, bf_even; - bf_odd = pand(reinterpret_cast(p4ui_high_mask), odd); - bf_even = plogical_shift_right<16>(even); - return reinterpret_cast(por(bf_even, bf_odd)); + return pmerge(reinterpret_cast(even), reinterpret_cast(odd)); } +//#define SUPPORT_BF16_SUBNORMALS + +#ifndef __VEC_CLASS_FP_NAN +#define __VEC_CLASS_FP_NAN (1<<6) +#endif + +#if defined(SUPPORT_BF16_SUBNORMALS) && !defined(__VEC_CLASS_FP_SUBNORMAL) +#define __VEC_CLASS_FP_SUBNORMAL_P (1<<1) +#define __VEC_CLASS_FP_SUBNORMAL_N (1<<0) + +#define __VEC_CLASS_FP_SUBNORMAL (__VEC_CLASS_FP_SUBNORMAL_P | __VEC_CLASS_FP_SUBNORMAL_N) +#endif + EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){ +#ifdef _ARCH_PWR10 + return reinterpret_cast(__builtin_vsx_xvcvspbf16(reinterpret_cast(p4f))); +#else Packet4ui input = reinterpret_cast(p4f); Packet4ui lsb = plogical_shift_right<16>(input); lsb = pand(lsb, reinterpret_cast(p4i_ONE)); @@ -1655,43 +1649,202 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){ Packet4ui rounding_bias = padd(lsb, p4ui_BIAS); input = padd(input, rounding_bias); - //Test NaN and Subnormal - Begin + const EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000); +#ifdef _ARCH_PWR9 + Packet4bi nan_selector = vec_test_data_class(p4f, __VEC_CLASS_FP_NAN); + input = vec_sel(input, p4ui_nan, nan_selector); + +#ifdef SUPPORT_BF16_SUBNORMALS + Packet4bi subnormal_selector = vec_test_data_class(p4f, __VEC_CLASS_FP_SUBNORMAL); + input = vec_sel(input, reinterpret_cast(p4f), subnormal_selector); +#endif +#else +#ifdef SUPPORT_BF16_SUBNORMALS + //Test NaN and Subnormal const EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000); Packet4ui exp = pand(p4ui_exp_mask, reinterpret_cast(p4f)); const EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF); Packet4ui mantissa = pand(p4ui_mantissa_mask, reinterpret_cast(p4f)); - const 
EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000); - Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp); - Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast(p4i_ZERO)); - + Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_exp_mask); Packet4bi is_mant_zero = vec_cmpeq(mantissa, reinterpret_cast(p4i_ZERO)); + Packet4ui nan_selector = pandnot( reinterpret_cast(is_max_exp), reinterpret_cast(is_mant_zero) ); + Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast(p4i_ZERO)); + Packet4ui subnormal_selector = pandnot( reinterpret_cast(is_zero_exp), reinterpret_cast(is_mant_zero) ); - const EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000); input = vec_sel(input, p4ui_nan, nan_selector); input = vec_sel(input, reinterpret_cast(p4f), subnormal_selector); - //Test NaN and Subnormal - End +#else + //Test only NaN + Packet4bi nan_selector = vec_cmpeq(p4f, p4f); + + input = vec_sel(p4ui_nan, input, nan_selector); +#endif +#endif input = plogical_shift_right<16>(input); return reinterpret_cast(input); +#endif } +#ifdef _BIG_ENDIAN +/** + * Pack the high portion of two float Packets into one bfloat16 Packet + * + * @param lohi to expect either a low & high OR odd & even order + */ +template +EIGEN_ALWAYS_INLINE Packet8bf Bf16PackHigh(Packet4f lo, Packet4f hi) +{ + if (lohi) { + return vec_perm(reinterpret_cast(lo), reinterpret_cast(hi), p16uc_MERGEH16); + } else { + return vec_perm(reinterpret_cast(hi), reinterpret_cast(lo), p16uc_MERGEE16); + } +} + +/** + * Pack the low portion of two float Packets into one bfloat16 Packet + * + * @param lohi to expect either a low & high OR odd & even order + */ +template +EIGEN_ALWAYS_INLINE Packet8bf Bf16PackLow(Packet4f lo, Packet4f hi) +{ + if (lohi) { + return vec_pack(reinterpret_cast(lo), reinterpret_cast(hi)); + } else { + return vec_perm(reinterpret_cast(hi), reinterpret_cast(lo), p16uc_MERGEO16); + } +} +#else +template +EIGEN_ALWAYS_INLINE Packet8bf Bf16PackLow(Packet4f hi, Packet4f lo) +{ + if (lohi) { + return vec_pack(reinterpret_cast(hi), reinterpret_cast(lo)); + } else { + return vec_perm(reinterpret_cast(hi), reinterpret_cast(lo), p16uc_MERGEE16); + } +} + +template +EIGEN_ALWAYS_INLINE Packet8bf Bf16PackHigh(Packet4f hi, Packet4f lo) +{ + if (lohi) { + return vec_perm(reinterpret_cast(hi), reinterpret_cast(lo), p16uc_MERGEL16); + } else { + return vec_perm(reinterpret_cast(hi), reinterpret_cast(lo), p16uc_MERGEO16); + } +} +#endif + +/** + * Convert and pack two float Packets into one bfloat16 Packet + * + * @param lohi to expect either a low & high OR odd & even order + */ +template +EIGEN_ALWAYS_INLINE Packet8bf F32ToBf16Two(Packet4f lo, Packet4f hi) +{ + Packet8us p4f = Bf16PackHigh(lo, hi); + Packet8us p4f2 = Bf16PackLow(lo, hi); + + Packet8us lsb = pand(p4f, p8us_ONE); + EIGEN_DECLARE_CONST_FAST_Packet8us(BIAS,0x7FFFu); + lsb = padd(lsb, p8us_BIAS); + lsb = padd(lsb, p4f2); + + Packet8bi rounding_bias = vec_cmplt(lsb, p4f2); + Packet8us input = psub(p4f, reinterpret_cast(rounding_bias)); + +#ifdef _ARCH_PWR9 + Packet4bi nan_selector_lo = vec_test_data_class(lo, __VEC_CLASS_FP_NAN); + Packet4bi nan_selector_hi = vec_test_data_class(hi, __VEC_CLASS_FP_NAN); + Packet8us nan_selector = Bf16PackLow(reinterpret_cast(nan_selector_lo), reinterpret_cast(nan_selector_hi)); + + input = vec_sel(input, p8us_BIAS, nan_selector); + +#ifdef SUPPORT_BF16_SUBNORMALS + Packet4bi subnormal_selector_lo = vec_test_data_class(lo, __VEC_CLASS_FP_SUBNORMAL); + Packet4bi subnormal_selector_hi = vec_test_data_class(hi, __VEC_CLASS_FP_SUBNORMAL); + Packet8us 
subnormal_selector = Bf16PackLow(reinterpret_cast(subnormal_selector_lo), reinterpret_cast(subnormal_selector_hi)); + + input = vec_sel(input, reinterpret_cast(p4f), subnormal_selector); +#endif +#else +#ifdef SUPPORT_BF16_SUBNORMALS + //Test NaN and Subnormal + const EIGEN_DECLARE_CONST_FAST_Packet8us(exp_mask, 0x7F80); + Packet8us exp = pand(p8us_exp_mask, p4f); + + const EIGEN_DECLARE_CONST_FAST_Packet8us(mantissa_mask, 0x7Fu); + Packet8us mantissa = pand(p8us_mantissa_mask, p4f); + + Packet8bi is_max_exp = vec_cmpeq(exp, p8us_exp_mask); + Packet8bi is_mant_zero = vec_cmpeq(mantissa, reinterpret_cast(p4i_ZERO)); + + Packet8us nan_selector = pandnot( + reinterpret_cast(is_max_exp), + reinterpret_cast(is_mant_zero) + ); + + Packet8bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast(p4i_ZERO)); + + Packet8us subnormal_selector = pandnot( + reinterpret_cast(is_zero_exp), + reinterpret_cast(is_mant_zero) + ); + + // Using BIAS as NaN (since any or all of the last 7 bits can be set) + input = vec_sel(input, p8us_BIAS, nan_selector); + input = vec_sel(input, reinterpret_cast(p4f), subnormal_selector); +#else + //Test only NaN + Packet4bi nan_selector_lo = vec_cmpeq(lo, lo); + Packet4bi nan_selector_hi = vec_cmpeq(hi, hi); + Packet8us nan_selector = Bf16PackLow(reinterpret_cast(nan_selector_lo), reinterpret_cast(nan_selector_hi)); + + input = vec_sel(p8us_BIAS, input, nan_selector); +#endif +#endif + + return input; +} + +/** + * Convert and pack two float Packets into one bfloat16 Packet - low & high order + */ +EIGEN_STRONG_INLINE Packet8bf F32ToBf16Both(Packet4f lo, Packet4f hi) +{ +#ifdef _ARCH_PWR10 + Packet8bf fp16_0 = F32ToBf16(lo); + Packet8bf fp16_1 = F32ToBf16(hi); + return vec_pack(reinterpret_cast(fp16_0.m_val), reinterpret_cast(fp16_1.m_val)); +#else + return F32ToBf16Two(lo, hi); +#endif +} + +/** + * Convert and pack two float Packets into one bfloat16 Packet - odd & even order + */ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){ - Packet4f bf_odd, bf_even; - bf_odd = reinterpret_cast(F32ToBf16(odd).m_val); - bf_odd = plogical_shift_left<16>(bf_odd); - bf_even = reinterpret_cast(F32ToBf16(even).m_val); - return reinterpret_cast(por(bf_even, bf_odd)); +#ifdef _ARCH_PWR10 + return pmerge(reinterpret_cast(F32ToBf16(even).m_val), reinterpret_cast(F32ToBf16(odd).m_val)); +#else + return F32ToBf16Two(even, odd); +#endif } #define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \ Packet4f a_even = Bf16ToF32Even(A);\ @@ -2493,11 +2646,7 @@ ptranspose(PacketBlock& kernel) { template EIGEN_STRONG_INLINE Packet pblend4(const Selector<4>& ifPacket, const Packet& thenPacket, const Packet& elsePacket) { Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] }; -#ifdef __POWER8_VECTOR__ - Packet4ui mask = reinterpret_cast(vec_neg(reinterpret_cast(select))); -#else - Packet4ui mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), reinterpret_cast(p4i_ONE))); -#endif + Packet4ui mask = reinterpret_cast(pnegate(reinterpret_cast(select))); return vec_sel(elsePacket, thenPacket, mask); } @@ -2512,11 +2661,7 @@ template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, cons template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, const Packet8s& thenPacket, const Packet8s& elsePacket) { Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] }; -#ifdef __POWER8_VECTOR__ - Packet8us mask 
= reinterpret_cast(vec_neg(reinterpret_cast(select))); -#else - Packet8us mask = reinterpret_cast(vec_cmpeq(select, p8us_ONE)); -#endif + Packet8us mask = reinterpret_cast(pnegate(reinterpret_cast(select))); Packet8s result = vec_sel(elsePacket, thenPacket, mask); return result; } @@ -2524,11 +2669,7 @@ template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, cons template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, const Packet8us& thenPacket, const Packet8us& elsePacket) { Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] }; -#ifdef __POWER8_VECTOR__ - Packet8us mask = reinterpret_cast(vec_neg(reinterpret_cast(select))); -#else - Packet8us mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), p8us_ONE)); -#endif + Packet8us mask = reinterpret_cast(pnegate(reinterpret_cast(select))); return vec_sel(elsePacket, thenPacket, mask); } @@ -2542,11 +2683,7 @@ template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, co ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11], ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] }; -#ifdef __POWER8_VECTOR__ - Packet16uc mask = reinterpret_cast(vec_neg(reinterpret_cast(select))); -#else - Packet16uc mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), p16uc_ONE)); -#endif + Packet16uc mask = reinterpret_cast(pnegate(reinterpret_cast(select))); return vec_sel(elsePacket, thenPacket, mask); } @@ -2556,11 +2693,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc pblend(const Selector<16>& ifPacket, c ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11], ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] }; -#ifdef __POWER8_VECTOR__ - Packet16uc mask = reinterpret_cast(vec_neg(reinterpret_cast(select))); -#else - Packet16uc mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), p16uc_ONE)); -#endif + Packet16uc mask = reinterpret_cast(pnegate(reinterpret_cast(select))); return vec_sel(elsePacket, thenPacket, mask); } @@ -2636,10 +2769,7 @@ template<> EIGEN_STRONG_INLINE Packet8us pcast(const Packe low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector); } - low_odd = plogical_shift_left<16>(low_odd); - - Packet4ui int_final = por(low_even, low_odd); - return reinterpret_cast(int_final); + return pmerge(low_even, low_odd); } template<> EIGEN_STRONG_INLINE Packet8bf pcast(const Packet8us& a) { @@ -2937,7 +3067,21 @@ template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) return vec_sld(a, a, 8); } template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); } +#ifdef __POWER8_VECTOR__ template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) { return (Packet2d)vec_sra((Packet2l)a, vec_splats((unsigned long long)(63))); } +#else +#ifdef _BIG_ENDIAN +static Packet16uc p16uc_DUPSIGN = { 0,0,0,0, 0,0,0,0, 8,8,8,8, 8,8,8,8 }; +#else +static Packet16uc p16uc_DUPSIGN = { 7,7,7,7, 7,7,7,7, 15,15,15,15, 15,15,15,15 }; +#endif + +template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) +{ + Packet16c tmp = vec_sra(reinterpret_cast(a), vec_splats((unsigned char)(7))); + return reinterpret_cast(vec_perm(tmp, tmp, p16uc_DUPSIGN)); +} +#endif // VSX support varies between different compilers and even different // versions of the same compiler. 
For gcc version >= 4.9.3, we can use
// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
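// Scalar sketch (assumption, not from the patch) of the rounding the vectorized
// F32ToBf16 path above performs: add a round-to-nearest-even bias to the low
// 16 bits, quiet NaNs to 0x7FC0, and keep the high 16 bits. The optional
// SUPPORT_BF16_SUBNORMALS passthrough is deliberately omitted, and all names
// here are illustrative rather than Eigen APIs.
#include <cstdint>
#include <cstring>

static inline uint16_t f32_to_bf16_scalar(float f) {
  uint32_t input;
  std::memcpy(&input, &f, sizeof(input));        // bit pattern of the float
  if ((input & 0x7FFFFFFFu) > 0x7F800000u) {     // NaN: return a quiet NaN,
    return 0x7FC0u;                              // matching p4ui_nan >> 16
  }
  uint32_t lsb = (input >> 16) & 1u;             // lowest bit that survives truncation
  uint32_t rounding_bias = 0x7FFFu + lsb;        // ties round to even
  input += rounding_bias;
  return static_cast<uint16_t>(input >> 16);     // keep the high 16 bits
}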
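// Minimal sketch (assumption, not the Eigen helper itself) of the rewritten
// non-VSX pload_partial/pstore_partial fallback above: when n is zero the
// result is simply a zero packet, otherwise the valid bytes are staged
// through an aligned scratch buffer with a single memcpy instead of the old
// cascade of 8/4/2/1-byte copies.
#include <cstring>

template <typename Scalar, int PacketSize>
void pload_partial_sketch(Scalar (&out)[PacketSize], const Scalar* from, int n) {
  alignas(16) Scalar scratch[PacketSize] = {};      // tail lanes stay zero
  if (n > 0) {
    std::memcpy(scratch, from, n * sizeof(Scalar)); // copy only the valid bytes
  }
  std::memcpy(out, scratch, sizeof(scratch));       // stands in for pload_ignore(load)
}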
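// Scalar analogue (assumption) of the pblend change above: negating a 0/1
// selector in two's complement yields an all-zeros/all-ones lane mask, which
// is exactly what vec_sel needs to pick thenPacket or elsePacket per element.
#include <cstdint>

static inline int16_t blend_lane(int16_t thenVal, int16_t elseVal, uint16_t select01) {
  uint16_t mask = static_cast<uint16_t>(-static_cast<int32_t>(select01)); // 1 -> 0xFFFF, 0 -> 0x0000
  return static_cast<int16_t>((thenVal & mask) | (elseVal & static_cast<uint16_t>(~mask)));
}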
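// Scalar picture (assumption) of the pre-POWER8 psignbit fallback for
// Packet2d above: smear the IEEE sign bit across the whole 64-bit lane, which
// is what the byte-wise arithmetic shift plus the p16uc_DUPSIGN permute
// achieve in the vector code.
#include <cstdint>
#include <cstring>

static inline uint64_t signbit_mask_scalar(double d) {
  uint64_t bits;
  std::memcpy(&bits, &d, sizeof(bits));
  return static_cast<uint64_t>(-static_cast<int64_t>(bits >> 63)); // all-ones if negative, else zero
}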