mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-22 17:49:36 +08:00
Revert the ODR changes and mark gemm_extra_cols and gemm_complex_extra_cols as EIGEN_ALWAYS_INLINE, so that they are always inlined and never emitted as external (out-of-line) functions.
This commit is contained in:
parent
f9659d91f1
commit
79cfc74f4d
@ -2188,7 +2188,7 @@ EIGEN_ALWAYS_INLINE void gemm_cols(
|
|||||||
gemm_cols<Scalar, Packet, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
|
gemm_cols<Scalar, Packet, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename DataMapper, const Index accCols>
|
template<typename Scalar, typename Packet, typename DataMapper, const Index accCols>
|
||||||
EIGEN_STRONG_INLINE void gemm_extra_cols(
|
EIGEN_ALWAYS_INLINE void gemm_extra_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
const Scalar* blockB,
|
const Scalar* blockB,
|
||||||
@ -2622,7 +2622,7 @@ EIGEN_ALWAYS_INLINE void gemm_complex_cols(
|
|||||||
gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
|
EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
const Scalar* blockB,
|
const Scalar* blockB,
|
||||||
|
@ -30,6 +30,23 @@ EIGEN_ALWAYS_INLINE void gemm_extra_row(
|
|||||||
const Packet& pAlpha,
|
const Packet& pAlpha,
|
||||||
const Packet& pMask);
|
const Packet& pMask);
|
||||||
|
|
||||||
|
template<typename Scalar, typename Packet, typename DataMapper, const Index accCols>
|
||||||
|
EIGEN_ALWAYS_INLINE void gemm_extra_cols(
|
||||||
|
const DataMapper& res,
|
||||||
|
const Scalar* blockA,
|
||||||
|
const Scalar* blockB,
|
||||||
|
Index depth,
|
||||||
|
Index strideA,
|
||||||
|
Index offsetA,
|
||||||
|
Index strideB,
|
||||||
|
Index offsetB,
|
||||||
|
Index col,
|
||||||
|
Index rows,
|
||||||
|
Index cols,
|
||||||
|
Index remaining_rows,
|
||||||
|
const Packet& pAlpha,
|
||||||
|
const Packet& pMask);
|
||||||
|
|
||||||
template<typename Packet>
|
template<typename Packet>
|
||||||
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows);
|
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows);
|
||||||
|
|
||||||
@ -50,7 +67,7 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
|
|||||||
const Packet& pMask);
|
const Packet& pMask);
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||||||
EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
|
EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(
|
||||||
const DataMapper& res,
|
const DataMapper& res,
|
||||||
const Scalar* blockA,
|
const Scalar* blockA,
|
||||||
const Scalar* blockB,
|
const Scalar* blockB,
|
||||||
|
@ -429,49 +429,6 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col) \
|
|
||||||
switch(value) { \
|
|
||||||
default: \
|
|
||||||
MICRO_EXTRA_UNROLL(1) \
|
|
||||||
break; \
|
|
||||||
case 2: \
|
|
||||||
if (is_col || (sizeof(Scalar) == sizeof(float))) { \
|
|
||||||
MICRO_EXTRA_UNROLL(2) \
|
|
||||||
} \
|
|
||||||
break; \
|
|
||||||
case 3: \
|
|
||||||
if (is_col || (sizeof(Scalar) == sizeof(float))) { \
|
|
||||||
MICRO_EXTRA_UNROLL(3) \
|
|
||||||
} \
|
|
||||||
break; \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define MICRO_EXTRA_COLS(N) \
|
|
||||||
gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
|
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename RhsPacket2, typename DataMapper, const Index accCols>
|
|
||||||
EIGEN_STRONG_INLINE void gemmMMA_extra_cols(
|
|
||||||
const DataMapper& res,
|
|
||||||
const Scalar* blockA,
|
|
||||||
const Scalar* blockB,
|
|
||||||
Index depth,
|
|
||||||
Index strideA,
|
|
||||||
Index offsetA,
|
|
||||||
Index strideB,
|
|
||||||
Index offsetB,
|
|
||||||
Index col,
|
|
||||||
Index rows,
|
|
||||||
Index cols,
|
|
||||||
Index remaining_rows,
|
|
||||||
const Packet& pAlpha,
|
|
||||||
const Packet& pMask)
|
|
||||||
{
|
|
||||||
MICRO_EXTRA(MICRO_EXTRA_COLS, cols-col, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
#undef MICRO_EXTRA
|
|
||||||
#undef MICRO_EXTRA_COLS
|
|
||||||
|
|
||||||
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
|
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
|
||||||
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||||
{
|
{
|
||||||
@ -493,7 +450,7 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
|
|||||||
|
|
||||||
if (col != cols)
|
if (col != cols)
|
||||||
{
|
{
|
||||||
gemmMMA_extra_cols<Scalar, Packet, RhsPacket2, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
|
gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user