diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h index 28868ca5a..35e7f673e 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h @@ -30,23 +30,6 @@ EIGEN_ALWAYS_INLINE void gemm_extra_row( const Packet& pAlpha, const Packet& pMask); -template -EIGEN_STRONG_INLINE void gemm_extra_cols( - const DataMapper& res, - const Scalar* blockA, - const Scalar* blockB, - Index depth, - Index strideA, - Index offsetA, - Index strideB, - Index offsetB, - Index col, - Index rows, - Index cols, - Index remaining_rows, - const Packet& pAlpha, - const Packet& pMask); - template EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows); diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h index e4013a747..b01dc92a4 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -429,6 +429,49 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols( } } +#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col) \ + switch(value) { \ + default: \ + MICRO_EXTRA_UNROLL(1) \ + break; \ + case 2: \ + if (is_col || (sizeof(Scalar) == sizeof(float))) { \ + MICRO_EXTRA_UNROLL(2) \ + } \ + break; \ + case 3: \ + if (is_col || (sizeof(Scalar) == sizeof(float))) { \ + MICRO_EXTRA_UNROLL(3) \ + } \ + break; \ + } + +#define MICRO_EXTRA_COLS(N) \ + gemmMMA_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); + +template +EIGEN_STRONG_INLINE void gemmMMA_extra_cols( + const DataMapper& res, + const Scalar* blockA, + const Scalar* blockB, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index offsetB, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask) +{ + MICRO_EXTRA(MICRO_EXTRA_COLS, cols-col, true) +} + +#undef MICRO_EXTRA +#undef MICRO_EXTRA_COLS + template void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { @@ -450,7 +493,7 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, if (col != cols) { - gemm_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); + gemmMMA_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); } }