Revert changes that made BF16 GEMM cause bad register spillage for LLVM (Power)

Chip Kerchner authored 2023-03-13 23:36:06 +00:00; committed by Rasmus Munk Larsen
parent 8fe6190001
commit 6c58f0fe1f


@@ -146,8 +146,8 @@ EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f
   zeroAccumulators<num_acc>(quad_acc);
-  Index k = 0;
-  for(Index j = depth >> 1; j--; k += 2){
+  Index k;
+  for(k = 0; k + 2 <= depth; k += 2){
     KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
   }
   if(depth&1){
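The same change recurs in the hunks below: a count-down loop over a precomputed trip count is reverted to a plain index-comparison loop. A minimal standalone sketch of the two shapes, with a hypothetical process() standing in for KLoop and the other loop bodies; per the commit title, the count-down form made LLVM on Power spill registers, presumably because it keeps two induction variables (the trip counter and the element index) live across the body:

    // Sketch only; process() is a hypothetical stand-in for KLoop et al.
    void process(long k);

    void countdownForm(long depth) {            // shape being reverted
      long k = 0;
      for (long j = depth >> 1; j--; k += 2) {  // j and k both live in the body
        process(k);
      }
    }

    void comparisonForm(long depth) {           // shape being restored
      for (long k = 0; k + 2 <= depth; k += 2) {  // single induction variable
        process(k);
      }
    }

Both forms call process() on the same even indices 0, 2, ..., so the revert is purely a code-generation change.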
@@ -185,9 +185,7 @@ void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f
 template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
 EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
 {
-  if (MAX_BFLOAT16_ACC > num_acc) {
-    colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
-  }
+  colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
 }
 template<const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
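This hunk (and the matching one in colVecColLoopBodyExtraN further down) also drops the MAX_BFLOAT16_ACC bound check around the dispatch, making the call unconditional; presumably the callers already keep num_acc within range, so the guard was redundant. A reduced sketch of the compile-time dispatch idiom, with made-up names (kMaxAcc, body, bodyExtraN):

    #include <cstdio>

    constexpr long kMaxAcc = 8;  // stands in for MAX_BFLOAT16_ACC

    template <long num_acc>
    void body(long depth) {      // stands in for colLoopBody
      std::printf("num_acc=%ld depth=%ld\n", num_acc, depth);
    }

    template <long num_acc, bool extraCols>
    void bodyExtraN(long depth) {
      // Bump the accumulator count at compile time when there are extra
      // columns, exactly as the now-unconditional call above does.
      body<num_acc + (extraCols ? 1 : 0)>(depth);
    }

    int main() {
      bodyExtraN<kMaxAcc - 1, true>(10);  // dispatches to body<kMaxAcc>
    }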
@@ -415,7 +413,7 @@ EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Inde
 template<const Index size, bool inc>
 EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc)
 {
-  for(Index j = (rows - i) / size; j--; i += size, dst += size*resInc){
+  for(; i + size <= rows; i += size, dst += size*resInc){
     PacketBlock<Packet8bf,(size+4)/8> r32;
     r32.packet[0] = convertF32toBF16<size != 4>(result + i + 0);
     if (size >= 16) {
@@ -569,12 +567,11 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
   zeroAccumulators<num_acc>(quad_acc);
   LhsMapper lhs2 = lhs.getSubMapper(row, 0);
-  Index j = 0;
-  for(Index k = cend >> 1; k--; j += 2) {
+  for(Index j = 0; j + 2 <= cend; j += 2) {
     vecColLoop<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, quad_acc);
   }
   if (cend & 1) {
-    vecColLoop<num_acc, LhsMapper, RhsMapper, true>(j, lhs2, rhs, quad_acc);
+    vecColLoop<num_acc, LhsMapper, RhsMapper, true>(cend - 1, lhs2, rhs, quad_acc);
   }
   disassembleAccumulators<num_acc>(quad_acc, acc);
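Two things to note in this hunk. First, since the loop index is now scoped to the for statement, the odd-element tail passes cend - 1 explicitly instead of reusing the leftover value of j. Second, quad_acc is an array of POWER10 MMA accumulators; a rough sketch of the zero/accumulate/disassemble lifecycle that the zeroAccumulators and disassembleAccumulators helpers wrap, using the GCC/Clang MMA builtins (illustrative only, compile with -mcpu=power10; bf16GerSketch and its arguments are hypothetical):

    typedef __vector unsigned char vec_t;

    void bf16GerSketch(const vec_t* a, const vec_t* b, float* out, long depth) {
      __vector_quad acc;
      __builtin_mma_xxsetaccz(&acc);          // like zeroAccumulators: clear the 4x4 f32 tile
      for (long k = 0; k < depth; k++) {
        __builtin_mma_xvbf16ger2pp(&acc, a[k], b[k]);  // bf16 outer product, accumulating
      }
      vec_t rows[4];
      __builtin_mma_disassemble_acc(rows, &acc);  // like disassembleAccumulators: tile -> VSRs
      __builtin_memcpy(out, rows, sizeof(rows));  // store the 16 f32 results
    }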
@@ -588,9 +585,7 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
 template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
 EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
-  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
-    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
-  }
+  colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
 }
 template<typename LhsMapper, typename RhsMapper, bool extraRows>
@@ -769,7 +764,7 @@ template<Index num_acc, typename LhsMapper, typename RhsMapper>
 EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc], Index extra_cols)
 {
   Index j = 0;
-  for(Index k = cols >> 3; k--; j += 8) {
+  for(; j + 8 <= cols; j += 8){
     multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs, j, extra_cols);
   }