diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h index b30c4f8e3..22dbab585 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h @@ -185,7 +185,9 @@ void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f template EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result) { - colLoopBody(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); + if (MAX_BFLOAT16_ACC > num_acc) { + colLoopBody(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result); + } } template @@ -585,7 +587,9 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa template EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) { - colVecColLoopBody(row, cend, rows, lhs, rhs, pAlpha, result); + if (MAX_BFLOAT16_VEC_ACC > num_acc) { + colVecColLoopBody(row, cend, rows, lhs, rhs, pAlpha, result); + } } template @@ -798,30 +802,38 @@ void colVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMappe } while(multiIters && (num_acc <= rows - (row += num_acc))); } +template +EIGEN_ALWAYS_INLINE void colVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) +{ + if (MAX_BFLOAT16_VEC_ACC > num_acc) { + colVecLoopBody(row, cols, rows, lhs, rhs, pAlpha, result); + } +} + template EIGEN_ALWAYS_INLINE void colVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result) { switch (rows - row) { case 7: - colVecLoopBody<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + colVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); break; case 6: - colVecLoopBody<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + colVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); break; case 5: - colVecLoopBody<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + colVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); break; case 4: - colVecLoopBody<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + colVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); break; case 3: - colVecLoopBody<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + colVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); break; case 2: - colVecLoopBody<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + colVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); break; case 1: - colVecLoopBody<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); + colVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result); break; } }