mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-03 03:35:11 +08:00
Fix problems with recent changes and Tensorflow in Power
This commit is contained in:
parent
ba1cb6e45e
commit
7769eb1b2e
@ -15,8 +15,6 @@
|
||||
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
|
||||
#endif
|
||||
|
||||
#include "MatrixProductCommon.h"
|
||||
|
||||
#if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
|
||||
#define EIGEN_ALTIVEC_DISABLE_MMA 0
|
||||
#endif
|
||||
@ -45,6 +43,8 @@
|
||||
|
||||
#endif // EIGEN_ALTIVEC_MMA_SUPPORT
|
||||
|
||||
#include "MatrixProductCommon.h"
|
||||
|
||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||
#include "MatrixProductMMA.h"
|
||||
#endif
|
||||
@ -2713,12 +2713,10 @@ EIGEN_ALWAYS_INLINE bool supportsMMA()
|
||||
{
|
||||
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
|
||||
return true;
|
||||
#else
|
||||
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && __has_builtin(__builtin_cpu_supports)
|
||||
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && defined(__BUILTIN_CPU_SUPPORTS__)
|
||||
return __builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma");
|
||||
#else
|
||||
return false; // No dynamic dispatch for LLVM
|
||||
#endif
|
||||
return false; // No dynamic dispatch for LLVM or older GCC
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -5,7 +5,7 @@
|
||||
#define EIGEN_POWER_PREFETCH(p)
|
||||
#endif
|
||||
|
||||
#ifdef _ARCH_PWR9
|
||||
#if defined(_ARCH_PWR9) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
|
||||
#define USE_PARTIAL_PACKETS
|
||||
#endif
|
||||
|
||||
|
@ -394,10 +394,12 @@ EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __v
|
||||
b0 = vec_mergeh(b0.m_val, b1.m_val);
|
||||
}
|
||||
|
||||
LhsMapper lhs2 = lhs.getSubMapper(0, j);
|
||||
using LhsSubMapper = typename LhsMapper::SubMapper;
|
||||
|
||||
LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
|
||||
BFLOAT16_UNROLL
|
||||
for(Index k = 0; k < num_acc; k += 2) {
|
||||
loadVecLoop<num_acc, LhsMapper, zero>(k, lhs2, a0, b1);
|
||||
loadVecLoop<num_acc, LhsSubMapper, zero>(k, lhs2, a0, b1);
|
||||
}
|
||||
|
||||
multVec<num_acc>(quad_acc, a0, b0);
|
||||
@ -418,12 +420,14 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
|
||||
|
||||
zeroAccumulators<num_acc>(quad_acc);
|
||||
|
||||
LhsMapper lhs2 = lhs.getSubMapper(row, 0);
|
||||
using LhsSubMapper = typename LhsMapper::SubMapper;
|
||||
|
||||
LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
|
||||
for(Index j = 0; j + 2 <= cend; j += 2) {
|
||||
vecColLoop<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
|
||||
vecColLoop<num_acc, LhsSubMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
|
||||
}
|
||||
if (cend & 1) {
|
||||
vecColLoop<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
|
||||
vecColLoop<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
|
||||
}
|
||||
|
||||
disassembleAccumulators<num_acc>(quad_acc, acc);
|
||||
@ -490,6 +494,33 @@ EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs,
|
||||
}
|
||||
}
|
||||
|
||||
template<typename RhsMapper, typename LhsMapper, typename = void>
|
||||
struct UseMMAStride : std::false_type {
|
||||
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
|
||||
{
|
||||
using RhsSubMapper = typename RhsMapper::SubMapper;
|
||||
|
||||
RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
|
||||
calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename RhsMapper, typename LhsMapper>
|
||||
struct UseMMAStride<RhsMapper, LhsMapper, std::enable_if_t<std::is_member_function_pointer<
|
||||
decltype(&RhsMapper::stride)>::value>> : std::true_type {
|
||||
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
|
||||
{
|
||||
using RhsSubMapper = typename RhsMapper::SubMapper;
|
||||
|
||||
RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
|
||||
if (rhs.stride() == 1) {
|
||||
calcVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2, rows, lhs, rhs2, pAlpha, result);
|
||||
} else {
|
||||
calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename LhsMapper, typename RhsMapper>
|
||||
void gemvMMA_bfloat16_col(
|
||||
Index rows, Index cols,
|
||||
@ -498,8 +529,6 @@ void gemvMMA_bfloat16_col(
|
||||
bfloat16* res, Index resIncr,
|
||||
bfloat16 alpha)
|
||||
{
|
||||
typedef typename RhsMapper::LinearMapper LinearMapper;
|
||||
|
||||
EIGEN_UNUSED_VARIABLE(resIncr);
|
||||
eigen_internal_assert(resIncr == 1);
|
||||
|
||||
@ -523,14 +552,10 @@ void gemvMMA_bfloat16_col(
|
||||
{
|
||||
Index jend = numext::mini(j2 + block_cols, cols);
|
||||
|
||||
LhsMapper lhs2 = lhs.getSubMapper(0, j2);
|
||||
if (rhs.stride() == 1) {
|
||||
LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
|
||||
calcVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
|
||||
} else {
|
||||
RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
|
||||
calcVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
|
||||
}
|
||||
using LhsSubMapper = typename LhsMapper::SubMapper;
|
||||
|
||||
LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
|
||||
UseMMAStride<RhsMapper, LhsSubMapper>::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);
|
||||
}
|
||||
|
||||
convertArrayPointerF32toBF16(result, rows, res);
|
||||
|
@ -527,7 +527,13 @@ struct loadColData_impl
|
||||
// linear == false
|
||||
static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j)
|
||||
{
|
||||
return pgather<bfloat16, Packet8bf>(&rhs(j + 0, 0), rhs.stride());
|
||||
const Index n = unpacket_traits<Packet8bf>::size;
|
||||
EIGEN_ALIGN16 bfloat16 to[n];
|
||||
LOAD_STORE_UNROLL_16
|
||||
for (Index i = 0; i < n; i++) {
|
||||
to[i] = rhs(j + i, 0);
|
||||
}
|
||||
return pload<Packet8bf>(to);
|
||||
}
|
||||
};
|
||||
|
||||
@ -537,7 +543,7 @@ struct loadColData_impl<RhsMapper, true>
|
||||
// linear == true
|
||||
static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j)
|
||||
{
|
||||
return rhs.template loadPacket<Packet8bf>(j + 0);
|
||||
return rhs.template loadPacket<Packet8bf>(j + 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
@ -558,9 +564,11 @@ EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs,
|
||||
b0[1] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V2);
|
||||
}
|
||||
|
||||
LhsMapper lhs2 = lhs.getSubMapper(0, j);
|
||||
using LhsSubMapper = typename LhsMapper::SubMapper;
|
||||
|
||||
LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
|
||||
for(Index k = 0; k < num_acc; k += 2) {
|
||||
loadVecLoopVSX<num_acc, LhsMapper, zero>(k, lhs2, a0);
|
||||
loadVecLoopVSX<num_acc, LhsSubMapper, zero>(k, lhs2, a0);
|
||||
}
|
||||
|
||||
multVecVSX<num_acc, zero>(acc, a0, b0);
|
||||
@ -589,12 +597,14 @@ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, Rh
|
||||
|
||||
zeroAccumulators<num_acc, 2>(acc);
|
||||
|
||||
LhsMapper lhs2 = lhs.getSubMapper(row, 0);
|
||||
using LhsSubMapper = typename LhsMapper::SubMapper;
|
||||
|
||||
LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
|
||||
for(Index j = 0; j + 2 <= cend; j += 2) {
|
||||
vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
|
||||
vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
|
||||
}
|
||||
if (cend & 1) {
|
||||
vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
|
||||
vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
|
||||
}
|
||||
|
||||
addResultsVSX<num_acc>(acc);
|
||||
@ -716,6 +726,33 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float *result, Index ro
|
||||
convertPointerF32toBF16VSX<1,inc>(i, result, rows, dst, resInc);
|
||||
}
|
||||
|
||||
template<typename RhsMapper, typename LhsMapper, typename = void>
|
||||
struct UseStride : std::false_type {
|
||||
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
|
||||
{
|
||||
using RhsSubMapper = typename RhsMapper::SubMapper;
|
||||
|
||||
RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
|
||||
calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename RhsMapper, typename LhsMapper>
|
||||
struct UseStride<RhsMapper, LhsMapper, std::enable_if_t<std::is_member_function_pointer<
|
||||
decltype(&RhsMapper::stride)>::value>> : std::true_type {
|
||||
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha, float *result)
|
||||
{
|
||||
using RhsSubMapper = typename RhsMapper::SubMapper;
|
||||
|
||||
RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
|
||||
if (rhs.stride() == 1) {
|
||||
calcVSXVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2, rows, lhs, rhs2, pAlpha, result);
|
||||
} else {
|
||||
calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename LhsMapper, typename RhsMapper>
|
||||
void gemv_bfloat16_col(
|
||||
Index rows, Index cols,
|
||||
@ -724,8 +761,6 @@ void gemv_bfloat16_col(
|
||||
bfloat16* res, Index resIncr,
|
||||
bfloat16 alpha)
|
||||
{
|
||||
typedef typename RhsMapper::LinearMapper LinearMapper;
|
||||
|
||||
EIGEN_UNUSED_VARIABLE(resIncr);
|
||||
eigen_internal_assert(resIncr == 1);
|
||||
|
||||
@ -749,14 +784,10 @@ void gemv_bfloat16_col(
|
||||
{
|
||||
Index jend = numext::mini(j2 + block_cols, cols);
|
||||
|
||||
LhsMapper lhs2 = lhs.getSubMapper(0, j2);
|
||||
if (rhs.stride() == 1) {
|
||||
LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
|
||||
calcVSXVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
|
||||
} else {
|
||||
RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
|
||||
calcVSXVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
|
||||
}
|
||||
using LhsSubMapper = typename LhsMapper::SubMapper;
|
||||
|
||||
LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
|
||||
UseStride<RhsMapper, LhsSubMapper>::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);
|
||||
}
|
||||
|
||||
convertArrayPointerF32toBF16VSX(result, rows, res);
|
||||
|
@ -182,6 +182,7 @@ class blas_data_mapper<Scalar,Index,StorageOrder,AlignmentType,1>
|
||||
{
|
||||
public:
|
||||
typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
|
||||
typedef blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType> SubMapper;
|
||||
typedef BlasVectorMapper<Scalar, Index> VectorMapper;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1)
|
||||
@ -191,9 +192,9 @@ public:
|
||||
eigen_assert(incr==1);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper
|
||||
getSubMapper(Index i, Index j) const {
|
||||
return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
|
||||
return SubMapper(&operator()(i, j), m_stride);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
|
||||
@ -316,12 +317,13 @@ class blas_data_mapper
|
||||
{
|
||||
public:
|
||||
typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;
|
||||
typedef blas_data_mapper SubMapper;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper
|
||||
getSubMapper(Index i, Index j) const {
|
||||
return blas_data_mapper(&operator()(i, j), m_stride, m_incr.value());
|
||||
return SubMapper(&operator()(i, j), m_stride, m_incr.value());
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
|
||||
@ -448,10 +450,12 @@ protected:
|
||||
template<typename Scalar, typename Index, int StorageOrder>
|
||||
class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
|
||||
public:
|
||||
typedef const_blas_data_mapper<Scalar, Index, StorageOrder> SubMapper;
|
||||
|
||||
EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
|
||||
|
||||
EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
|
||||
return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
|
||||
EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
|
||||
return SubMapper(&(this->operator()(i, j)), this->m_stride);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -443,6 +443,14 @@ class TensorContractionSubMapper {
|
||||
return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
template <typename PacketT>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacketPartial(Index i, Index j, Index, Index = 0) const {
|
||||
if (UseDirectOffsets) {
|
||||
return m_base_mapper.template loadPacket<PacketT,Alignment>(i, j);
|
||||
}
|
||||
return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
template <typename PacketT, int AlignmentType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
|
||||
if (UseDirectOffsets) {
|
||||
@ -473,6 +481,8 @@ class TensorContractionSubMapper {
|
||||
return SubMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
// Returns the stride reported by the wrapped base mapper.
// The value is returned as a plain `Index`: a top-level `const` on a by-value
// return is ignored for scalar types and only triggers -Wignored-qualifiers.
// `&T::stride` is still a pointer to a const member function, so member
// detection based on it is unaffected.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const { return m_base_mapper.stride(); }
|
||||
|
||||
template <typename PacketT, int AlignmentType>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
|
Loading…
x
Reference in New Issue
Block a user