diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index 54a7195f8..1c5c04859 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -15,8 +15,6 @@
 #define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
 #endif
 
-#include "MatrixProductCommon.h"
-
 #if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
 #define EIGEN_ALTIVEC_DISABLE_MMA 0
 #endif
@@ -45,6 +43,8 @@
 #endif // EIGEN_ALTIVEC_MMA_SUPPORT
 
+#include "MatrixProductCommon.h"
+
 #if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
 #include "MatrixProductMMA.h"
 #endif
 
@@ -2713,12 +2713,10 @@
 EIGEN_ALWAYS_INLINE bool supportsMMA()
 {
 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
   return true;
-#else
-#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && __has_builtin(__builtin_cpu_supports)
+#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && defined(__BUILTIN_CPU_SUPPORTS__)
   return __builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma");
 #else
-  return false; // No dynamic dispatch for LLVM
-#endif
+  return false; // No dynamic dispatch for LLVM or older GCC
 #endif
 }
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
index cb18311e1..daed8c165 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
@@ -5,7 +5,7 @@
 #define EIGEN_POWER_PREFETCH(p)
 #endif
 
-#ifdef _ARCH_PWR9
+#if defined(_ARCH_PWR9) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
 #define USE_PARTIAL_PACKETS
 #endif
 
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
index 011d68e69..509411839 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
@@ -394,10 +394,12 @@ EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __v
     b0 = vec_mergeh(b0.m_val, b1.m_val);
   }
 
-  LhsMapper lhs2 = lhs.getSubMapper(0, j);
+  using LhsSubMapper = typename LhsMapper::SubMapper;
+
+  LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
   BFLOAT16_UNROLL
   for(Index k = 0; k < num_acc; k += 2) {
-    loadVecLoop(k, lhs2, a0, b1);
+    loadVecLoop(k, lhs2, a0, b1);
   }
 
   multVec(quad_acc, a0, b0);
@@ -418,12 +420,14 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
   zeroAccumulators(quad_acc);
 
-  LhsMapper lhs2 = lhs.getSubMapper(row, 0);
+  using LhsSubMapper = typename LhsMapper::SubMapper;
+
+  LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
   for(Index j = 0; j + 2 <= cend; j += 2) {
-    vecColLoop(j, lhs2, rhs, quad_acc);
+    vecColLoop(j, lhs2, rhs, quad_acc);
   }
   if (cend & 1) {
-    vecColLoop(cend - 1, lhs2, rhs, quad_acc);
+    vecColLoop(cend - 1, lhs2, rhs, quad_acc);
   }
 
   disassembleAccumulators(quad_acc, acc);
@@ -490,6 +494,33 @@ EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs,
   }
 }
 
+template
+struct UseMMAStride : std::false_type {
+  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
+  {
+    using RhsSubMapper = typename RhsMapper::SubMapper;
+
+    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
+    calcVecColLoops(jend - j2, rows, lhs, rhs2, pAlpha, result);
+  }
+};
+
+template
+struct UseMMAStride::value>> : std::true_type {
+  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
+  {
+    using RhsSubMapper = typename RhsMapper::SubMapper;
+
+    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
+    if (rhs.stride() == 1) {
+      calcVecColLoops(jend - j2, rows, lhs, rhs2, pAlpha, result);
+    } else {
+      calcVecColLoops(jend - j2, rows, lhs, rhs2, pAlpha, result);
+    }
+  }
+};
+
 template
 void gemvMMA_bfloat16_col(
   Index rows, Index cols,
@@ -498,8 +529,6 @@ void gemvMMA_bfloat16_col(
   bfloat16* res, Index resIncr,
   bfloat16 alpha)
 {
-  typedef typename RhsMapper::LinearMapper LinearMapper;
-
   EIGEN_UNUSED_VARIABLE(resIncr);
   eigen_internal_assert(resIncr == 1);
 
@@ -523,14 +552,10 @@ void gemvMMA_bfloat16_col(
   {
     Index jend = numext::mini(j2 + block_cols, cols);
 
-    LhsMapper lhs2 = lhs.getSubMapper(0, j2);
-    if (rhs.stride() == 1) {
-      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
-      calcVecColLoops(jend - j2, rows, lhs2, rhs3, pAlpha, result);
-    } else {
-      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
-      calcVecColLoops(jend - j2, rows, lhs2, rhs3, pAlpha, result);
-    }
+    using LhsSubMapper = typename LhsMapper::SubMapper;
+
+    LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
+    UseMMAStride::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);
   }
 
   convertArrayPointerF32toBF16(result, rows, res);
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
index 73c6aa15b..917023013 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
@@ -527,7 +527,13 @@ struct loadColData_impl
   // linear == false
   static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j)
   {
-    return pgather(&rhs(j + 0, 0), rhs.stride());
+    const Index n = unpacket_traits::size;
+    EIGEN_ALIGN16 bfloat16 to[n];
+    LOAD_STORE_UNROLL_16
+    for (Index i = 0; i < n; i++) {
+      to[i] = rhs(j + i, 0);
+    }
+    return pload(to);
   }
 };
 
@@ -537,7 +543,7 @@ struct loadColData_impl
   // linear == true
   static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j)
   {
-    return rhs.template loadPacket(j + 0);
+    return rhs.template loadPacket(j + 0, 0);
   }
 };
 
@@ -558,9 +564,11 @@ EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs,
     b0[1] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V2);
   }
 
-  LhsMapper lhs2 = lhs.getSubMapper(0, j);
+  using LhsSubMapper = typename LhsMapper::SubMapper;
+
+  LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
   for(Index k = 0; k < num_acc; k += 2) {
-    loadVecLoopVSX(k, lhs2, a0);
+    loadVecLoopVSX(k, lhs2, a0);
   }
 
   multVecVSX(acc, a0, b0);
@@ -589,12 +597,14 @@ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, Rh
   zeroAccumulators(acc);
 
-  LhsMapper lhs2 = lhs.getSubMapper(row, 0);
+  using LhsSubMapper = typename LhsMapper::SubMapper;
+
+  LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
   for(Index j = 0; j + 2 <= cend; j += 2) {
-    vecColLoopVSX(j, lhs2, rhs, acc);
+    vecColLoopVSX(j, lhs2, rhs, acc);
  }
   if (cend & 1) {
-    vecColLoopVSX(cend - 1, lhs2, rhs, acc);
+    vecColLoopVSX(cend - 1, lhs2, rhs, acc);
   }
 
   addResultsVSX(acc);
@@ -716,6 +726,33 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float *result, Index ro
   convertPointerF32toBF16VSX<1,inc>(i, result, rows, dst, resInc);
 }
 
+template
+struct UseStride : std::false_type {
+  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
+  {
+    using RhsSubMapper = typename RhsMapper::SubMapper;
+
+    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
+    calcVSXVecColLoops(jend - j2, rows, lhs, rhs2, pAlpha, result);
+  }
+};
+
+template
+struct UseStride::value>> : std::true_type {
+  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
+  {
+    using RhsSubMapper = typename RhsMapper::SubMapper;
+
+    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
+    if (rhs.stride() == 1) {
+      calcVSXVecColLoops(jend - j2, rows, lhs, rhs2, pAlpha, result);
+    } else {
+      calcVSXVecColLoops(jend - j2, rows, lhs, rhs2, pAlpha, result);
+    }
+  }
+};
+
 template
 void gemv_bfloat16_col(
   Index rows, Index cols,
@@ -724,8 +761,6 @@ void gemv_bfloat16_col(
   bfloat16* res, Index resIncr,
   bfloat16 alpha)
 {
-  typedef typename RhsMapper::LinearMapper LinearMapper;
-
   EIGEN_UNUSED_VARIABLE(resIncr);
   eigen_internal_assert(resIncr == 1);
 
@@ -749,14 +784,10 @@ void gemv_bfloat16_col(
   {
     Index jend = numext::mini(j2 + block_cols, cols);
 
-    LhsMapper lhs2 = lhs.getSubMapper(0, j2);
-    if (rhs.stride() == 1) {
-      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
-      calcVSXVecColLoops(jend - j2, rows, lhs2, rhs3, pAlpha, result);
-    } else {
-      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
-      calcVSXVecColLoops(jend - j2, rows, lhs2, rhs3, pAlpha, result);
-    }
+    using LhsSubMapper = typename LhsMapper::SubMapper;
+
+    LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
+    UseStride::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);
   }
 
   convertArrayPointerF32toBF16VSX(result, rows, res);
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index 41fbe1538..b0217a905 100644
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -182,6 +182,7 @@ class blas_data_mapper {
 public:
   typedef BlasLinearMapper LinearMapper;
+  typedef blas_data_mapper SubMapper;
   typedef BlasVectorMapper VectorMapper;
 
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1)
@@ -191,9 +192,9 @@ public:
     eigen_assert(incr==1);
   }
 
-  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper
     getSubMapper(Index i, Index j) const {
-    return blas_data_mapper(&operator()(i, j), m_stride);
+    return SubMapper(&operator()(i, j), m_stride);
   }
 
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
@@ -316,12 +317,13 @@ class blas_data_mapper {
 public:
   typedef BlasLinearMapper LinearMapper;
+  typedef blas_data_mapper SubMapper;
 
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr)
     : m_data(data), m_stride(stride), m_incr(incr) {}
 
-  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper
     getSubMapper(Index i, Index j) const {
-    return blas_data_mapper(&operator()(i, j), m_stride, m_incr.value());
+    return SubMapper(&operator()(i, j), m_stride, m_incr.value());
   }
 
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
@@ -448,10 +450,12 @@ protected:
 
 template
 class const_blas_data_mapper : public blas_data_mapper {
 public:
+  typedef const_blas_data_mapper SubMapper;
+
   EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper(data, stride) {}
 
-  EIGEN_ALWAYS_INLINE const_blas_data_mapper getSubMapper(Index i, Index j) const {
-    return const_blas_data_mapper(&(this->operator()(i, j)), this->m_stride);
+  EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
+    return SubMapper(&(this->operator()(i, j)), this->m_stride);
   }
 };
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
index 8478cc3c1..7fb54b961 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
@@ -443,6 +443,14 @@ class TensorContractionSubMapper {
     return m_base_mapper.template loadPacket(i + m_vert_offset, j + m_horiz_offset);
   }
 
+  template
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacketPartial(Index i, Index j, Index, Index = 0) const {
+    if (UseDirectOffsets) {
+      return m_base_mapper.template loadPacket(i, j);
+    }
+    return m_base_mapper.template loadPacket(i + m_vert_offset, j + m_horiz_offset);
+  }
+
   template
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
     if (UseDirectOffsets) {
@@ -473,6 +481,8 @@ class TensorContractionSubMapper {
     return SubMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
   }
 
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const Index stride() const { return m_base_mapper.stride(); }
+
   template
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
     EIGEN_STATIC_ASSERT((internal::is_same::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
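
The new UseStride/UseMMAStride helpers in the patch follow a detection-and-specialization pattern: a primary template (std::false_type) whose run() unconditionally takes the strided path on the RHS sub-mapper, and a specialization (std::true_type), selected via a trait on the mapper type, whose run() checks rhs.stride() at run time and picks the linear path when the stride is 1. Below is a minimal standalone sketch of that pattern under stated assumptions: VecMapper and UseStrideDemo are made-up illustration names, not Eigen classes, the exact enable condition used by the patch is not reproduced, and the real kernels dispatch to packet loads rather than scalar loops.

// Standalone sketch only: VecMapper and UseStrideDemo are hypothetical names,
// not Eigen types. Requires C++17 (std::void_t).
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>
#include <vector>

using Index = std::ptrdiff_t;

// Toy vector mapper: logical element i lives at data[i * inc].
struct VecMapper {
  using SubMapper = VecMapper;  // mirrors the SubMapper typedef the patch adds to blas_data_mapper
  const float* data;
  Index inc;                    // distance between consecutive logical elements
  Index stride() const { return inc; }
  const float& operator()(Index i, Index) const { return data[i * inc]; }
  SubMapper getSubMapper(Index i, Index j) const { return SubMapper{&operator()(i, j), inc}; }
};

// Primary template: no usable stride interface, always take the generic gather loop.
template <typename RhsMapper, typename Enable = void>
struct UseStrideDemo : std::false_type {
  static float run(const RhsMapper& rhs, Index n) {
    float sum = 0.0f;
    for (Index i = 0; i < n; ++i) sum += rhs(i, 0);  // element-by-element gather
    return sum;
  }
};

// Specialization, enabled only when RhsMapper::stride() exists: check the stride at
// run time and take a contiguous fast path when it is 1 (stand-in for linear packet loads).
template <typename RhsMapper>
struct UseStrideDemo<RhsMapper, std::void_t<decltype(std::declval<RhsMapper>().stride())>>
    : std::true_type {
  static float run(const RhsMapper& rhs, Index n) {
    float sum = 0.0f;
    if (rhs.stride() == 1) {
      const float* p = &rhs(0, 0);                    // contiguous: plain pointer walk
      for (Index i = 0; i < n; ++i) sum += p[i];
    } else {
      for (Index i = 0; i < n; ++i) sum += rhs(i, 0); // strided gather
    }
    return sum;
  }
};

int main() {
  std::vector<float> buf = {1, 2, 3, 4, 5, 6, 7, 8};
  VecMapper contiguous{buf.data(), 1};
  VecMapper strided{buf.data(), 2};
  std::cout << UseStrideDemo<VecMapper>::run(contiguous, 4) << "\n";  // 1+2+3+4 = 10
  std::cout << UseStrideDemo<VecMapper>::run(strided, 4) << "\n";     // 1+3+5+7 = 16
  return 0;
}

Compiled as C++17, the two calls print 10 (stride 1, fast path) and 16 (stride 2, gather path). The point of the sketch is the compile-time split: mapper types that cannot answer stride() fall back to the primary template, which is why the patch also adds a stride() accessor to TensorContractionSubMapper and the SubMapper typedefs to the blas_data_mapper family, so the GEMV kernels can name and query the sub-mapper types generically.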