Fix ODR violation with gemm_extra_cols on PPC

This commit is contained in:
Alexander Grund 2023-02-09 22:16:06 +00:00 committed by Rasmus Munk Larsen
parent 325e3063d9
commit f9659d91f1
2 changed files with 44 additions and 18 deletions

View File

@ -30,23 +30,6 @@ EIGEN_ALWAYS_INLINE void gemm_extra_row(
const Packet& pAlpha,
const Packet& pMask);
template<typename Scalar, typename Packet, typename DataMapper, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_extra_cols(
const DataMapper& res,
const Scalar* blockA,
const Scalar* blockB,
Index depth,
Index strideA,
Index offsetA,
Index strideB,
Index offsetB,
Index col,
Index rows,
Index cols,
Index remaining_rows,
const Packet& pAlpha,
const Packet& pMask);
template<typename Packet>
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows);

View File

@ -429,6 +429,49 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols(
}
}
#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col) \
switch(value) { \
default: \
MICRO_EXTRA_UNROLL(1) \
break; \
case 2: \
if (is_col || (sizeof(Scalar) == sizeof(float))) { \
MICRO_EXTRA_UNROLL(2) \
} \
break; \
case 3: \
if (is_col || (sizeof(Scalar) == sizeof(float))) { \
MICRO_EXTRA_UNROLL(3) \
} \
break; \
}
#define MICRO_EXTRA_COLS(N) \
gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
template<typename Scalar, typename Packet, typename RhsPacket2, typename DataMapper, const Index accCols>
EIGEN_STRONG_INLINE void gemmMMA_extra_cols(
const DataMapper& res,
const Scalar* blockA,
const Scalar* blockB,
Index depth,
Index strideA,
Index offsetA,
Index strideB,
Index offsetB,
Index col,
Index rows,
Index cols,
Index remaining_rows,
const Packet& pAlpha,
const Packet& pMask)
{
MICRO_EXTRA(MICRO_EXTRA_COLS, cols-col, true)
}
#undef MICRO_EXTRA
#undef MICRO_EXTRA_COLS
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
@ -450,7 +493,7 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
if (col != cols)
{
gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
gemmMMA_extra_cols<Scalar, Packet, RhsPacket2, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}
}