Fix problems with recent changes and TensorFlow on Power

This commit is contained in:
Chip Kerchner 2023-07-26 16:24:58 +00:00 committed by Antonio Sánchez
parent ba1cb6e45e
commit 7769eb1b2e
6 changed files with 113 additions and 45 deletions

View File

@ -15,8 +15,6 @@
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif
#include "MatrixProductCommon.h"
#if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
#define EIGEN_ALTIVEC_DISABLE_MMA 0
#endif
@ -45,6 +43,8 @@
#endif // EIGEN_ALTIVEC_MMA_SUPPORT
#include "MatrixProductCommon.h"
#if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#include "MatrixProductMMA.h"
#endif
@ -2713,12 +2713,10 @@ EIGEN_ALWAYS_INLINE bool supportsMMA()
{
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
return true;
#else
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && __has_builtin(__builtin_cpu_supports)
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && defined(__BUILTIN_CPU_SUPPORTS__)
return __builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma");
#else
return false; // No dynamic dispatch for LLVM
#endif
return false; // No dynamic dispatch for LLVM or older GCC
#endif
}

View File

@ -5,7 +5,7 @@
#define EIGEN_POWER_PREFETCH(p)
#endif
#ifdef _ARCH_PWR9
#if defined(_ARCH_PWR9) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#define USE_PARTIAL_PACKETS
#endif

View File

@ -394,10 +394,12 @@ EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __v
b0 = vec_mergeh(b0.m_val, b1.m_val);
}
LhsMapper lhs2 = lhs.getSubMapper(0, j);
using LhsSubMapper = typename LhsMapper::SubMapper;
LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
BFLOAT16_UNROLL
for(Index k = 0; k < num_acc; k += 2) {
loadVecLoop<num_acc, LhsMapper, zero>(k, lhs2, a0, b1);
loadVecLoop<num_acc, LhsSubMapper, zero>(k, lhs2, a0, b1);
}
multVec<num_acc>(quad_acc, a0, b0);
@ -418,12 +420,14 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
zeroAccumulators<num_acc>(quad_acc);
LhsMapper lhs2 = lhs.getSubMapper(row, 0);
using LhsSubMapper = typename LhsMapper::SubMapper;
LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
for(Index j = 0; j + 2 <= cend; j += 2) {
vecColLoop<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
vecColLoop<num_acc, LhsSubMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
}
if (cend & 1) {
vecColLoop<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
vecColLoop<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
}
disassembleAccumulators<num_acc>(quad_acc, acc);
@ -490,6 +494,33 @@ EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs,
}
}
// Dispatch helper for the MMA bfloat16 column GEMV: selects between a
// strided and a linear (stride == 1) inner loop depending on whether the
// RhsMapper type exposes a stride() member function.
//
// Primary template: chosen when RhsMapper has no stride() member function
// (detected via the SFINAE specialization below). Without stride
// information we must assume a non-unit stride, so the non-linear
// (linear = false) code path of calcVecColLoops is used unconditionally.
template<typename RhsMapper, typename LhsMapper, typename = void>
struct UseMMAStride : std::false_type {
// j2/jend: current column block [j2, jend); rows: number of result rows;
// pAlpha: broadcast alpha scale; result: float accumulation buffer.
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
{
using RhsSubMapper = typename RhsMapper::SubMapper;
// Rebase the rhs mapper to the start of this column block.
RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
}
};
// Specialization: enabled only when &RhsMapper::stride is a member function
// pointer, i.e. the mapper can report its stride at runtime. This lets us
// pick the faster linear (contiguous, linear = true) loop when stride == 1.
template<typename RhsMapper, typename LhsMapper>
struct UseMMAStride<RhsMapper, LhsMapper, std::enable_if_t<std::is_member_function_pointer<
decltype(&RhsMapper::stride)>::value>> : std::true_type {
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
{
using RhsSubMapper = typename RhsMapper::SubMapper;
RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
// Runtime dispatch: unit stride means rhs elements are contiguous and
// can be loaded with packet loads instead of gathers.
if (rhs.stride() == 1) {
calcVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2, rows, lhs, rhs2, pAlpha, result);
} else {
calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
}
}
};
template<typename LhsMapper, typename RhsMapper>
void gemvMMA_bfloat16_col(
Index rows, Index cols,
@ -498,8 +529,6 @@ void gemvMMA_bfloat16_col(
bfloat16* res, Index resIncr,
bfloat16 alpha)
{
typedef typename RhsMapper::LinearMapper LinearMapper;
EIGEN_UNUSED_VARIABLE(resIncr);
eigen_internal_assert(resIncr == 1);
@ -523,14 +552,10 @@ void gemvMMA_bfloat16_col(
{
Index jend = numext::mini(j2 + block_cols, cols);
LhsMapper lhs2 = lhs.getSubMapper(0, j2);
if (rhs.stride() == 1) {
LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
calcVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
} else {
RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
calcVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
}
using LhsSubMapper = typename LhsMapper::SubMapper;
LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
UseMMAStride<RhsMapper, LhsSubMapper>::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);
}
convertArrayPointerF32toBF16(result, rows, res);

View File

@ -527,7 +527,13 @@ struct loadColData_impl
// linear == false
static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j)
{
return pgather<bfloat16, Packet8bf>(&rhs(j + 0, 0), rhs.stride());
const Index n = unpacket_traits<Packet8bf>::size;
EIGEN_ALIGN16 bfloat16 to[n];
LOAD_STORE_UNROLL_16
for (Index i = 0; i < n; i++) {
to[i] = rhs(j + i, 0);
}
return pload<Packet8bf>(to);
}
};
@ -537,7 +543,7 @@ struct loadColData_impl<RhsMapper, true>
// linear == true
static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j)
{
return rhs.template loadPacket<Packet8bf>(j + 0);
return rhs.template loadPacket<Packet8bf>(j + 0, 0);
}
};
@ -558,9 +564,11 @@ EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs,
b0[1] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V2);
}
LhsMapper lhs2 = lhs.getSubMapper(0, j);
using LhsSubMapper = typename LhsMapper::SubMapper;
LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
for(Index k = 0; k < num_acc; k += 2) {
loadVecLoopVSX<num_acc, LhsMapper, zero>(k, lhs2, a0);
loadVecLoopVSX<num_acc, LhsSubMapper, zero>(k, lhs2, a0);
}
multVecVSX<num_acc, zero>(acc, a0, b0);
@ -589,12 +597,14 @@ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, Rh
zeroAccumulators<num_acc, 2>(acc);
LhsMapper lhs2 = lhs.getSubMapper(row, 0);
using LhsSubMapper = typename LhsMapper::SubMapper;
LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
for(Index j = 0; j + 2 <= cend; j += 2) {
vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
}
if (cend & 1) {
vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
}
addResultsVSX<num_acc>(acc);
@ -716,6 +726,33 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float *result, Index ro
convertPointerF32toBF16VSX<1,inc>(i, result, rows, dst, resInc);
}
// VSX counterpart of UseMMAStride: selects between a strided and a linear
// (stride == 1) inner loop for the bfloat16 column GEMV, depending on
// whether the RhsMapper type exposes a stride() member function.
//
// Primary template: chosen when RhsMapper has no stride() member function
// (detected via the SFINAE specialization below). Without stride
// information we must assume a non-unit stride, so the non-linear
// (linear = false) code path of calcVSXVecColLoops is used unconditionally.
template<typename RhsMapper, typename LhsMapper, typename = void>
struct UseStride : std::false_type {
// j2/jend: current column block [j2, jend); rows: number of result rows;
// pAlpha: broadcast alpha scale; result: float accumulation buffer.
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
{
using RhsSubMapper = typename RhsMapper::SubMapper;
// Rebase the rhs mapper to the start of this column block.
RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
}
};
// Specialization: enabled only when &RhsMapper::stride is a member function
// pointer, i.e. the mapper can report its stride at runtime. This lets us
// pick the faster linear (contiguous, linear = true) loop when stride == 1.
template<typename RhsMapper, typename LhsMapper>
struct UseStride<RhsMapper, LhsMapper, std::enable_if_t<std::is_member_function_pointer<
decltype(&RhsMapper::stride)>::value>> : std::true_type {
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
{
using RhsSubMapper = typename RhsMapper::SubMapper;
RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
// Runtime dispatch: unit stride means rhs elements are contiguous and
// can be loaded with packet loads instead of gathers.
if (rhs.stride() == 1) {
calcVSXVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2, rows, lhs, rhs2, pAlpha, result);
} else {
calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
}
}
};
template<typename LhsMapper, typename RhsMapper>
void gemv_bfloat16_col(
Index rows, Index cols,
@ -724,8 +761,6 @@ void gemv_bfloat16_col(
bfloat16* res, Index resIncr,
bfloat16 alpha)
{
typedef typename RhsMapper::LinearMapper LinearMapper;
EIGEN_UNUSED_VARIABLE(resIncr);
eigen_internal_assert(resIncr == 1);
@ -749,14 +784,10 @@ void gemv_bfloat16_col(
{
Index jend = numext::mini(j2 + block_cols, cols);
LhsMapper lhs2 = lhs.getSubMapper(0, j2);
if (rhs.stride() == 1) {
LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
calcVSXVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
} else {
RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
calcVSXVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
}
using LhsSubMapper = typename LhsMapper::SubMapper;
LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
UseStride<RhsMapper, LhsSubMapper>::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);
}
convertArrayPointerF32toBF16VSX(result, rows, res);

View File

@ -182,6 +182,7 @@ class blas_data_mapper<Scalar,Index,StorageOrder,AlignmentType,1>
{
public:
typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
typedef blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType> SubMapper;
typedef BlasVectorMapper<Scalar, Index> VectorMapper;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1)
@ -191,9 +192,9 @@ public:
eigen_assert(incr==1);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper
getSubMapper(Index i, Index j) const {
return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
return SubMapper(&operator()(i, j), m_stride);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
@ -316,12 +317,13 @@ class blas_data_mapper
{
public:
typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;
typedef blas_data_mapper SubMapper;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper
getSubMapper(Index i, Index j) const {
return blas_data_mapper(&operator()(i, j), m_stride, m_incr.value());
return SubMapper(&operator()(i, j), m_stride, m_incr.value());
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
@ -448,10 +450,12 @@ protected:
template<typename Scalar, typename Index, int StorageOrder>
class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
public:
typedef const_blas_data_mapper<Scalar, Index, StorageOrder> SubMapper;
EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
return SubMapper(&(this->operator()(i, j)), this->m_stride);
}
};

View File

@ -443,6 +443,14 @@ class TensorContractionSubMapper {
return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, j + m_horiz_offset);
}
template <typename PacketT>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacketPartial(Index i, Index j, Index, Index = 0) const {
if (UseDirectOffsets) {
return m_base_mapper.template loadPacket<PacketT,Alignment>(i, j);
}
return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, j + m_horiz_offset);
}
template <typename PacketT, int AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
if (UseDirectOffsets) {
@ -473,6 +481,8 @@ class TensorContractionSubMapper {
return SubMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const Index stride() const { return m_base_mapper.stride(); }
template <typename PacketT, int AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);