Revert changes that made BF16 GEMM to cause bad register spillage for LLVM (Power)

This commit is contained in:
Chip Kerchner 2023-03-13 23:36:06 +00:00 committed by Rasmus Munk Larsen
parent 8fe6190001
commit 6c58f0fe1f

View File

@ -146,8 +146,8 @@ EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f
zeroAccumulators<num_acc>(quad_acc);
Index k = 0;
for(Index j = depth >> 1; j--; k += 2){
Index k;
for(k = 0; k + 2 <= depth; k += 2){
KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
}
if(depth&1){
@ -185,9 +185,7 @@ void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
if (MAX_BFLOAT16_ACC > num_acc) {
colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
}
colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
}
template<const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
@ -415,7 +413,7 @@ EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Inde
template<const Index size, bool inc>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc)
{
for(Index j = (rows - i) / size; j--; i += size, dst += size*resInc){
for(; i + size <= rows; i += size, dst += size*resInc){
PacketBlock<Packet8bf,(size+4)/8> r32;
r32.packet[0] = convertF32toBF16<size != 4>(result + i + 0);
if (size >= 16) {
@ -569,12 +567,11 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
zeroAccumulators<num_acc>(quad_acc);
LhsMapper lhs2 = lhs.getSubMapper(row, 0);
Index j = 0;
for(Index k = cend >> 1; k--; j += 2) {
for(Index j = 0; j + 2 <= cend; j += 2) {
vecColLoop<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, quad_acc);
}
if (cend & 1) {
vecColLoop<num_acc, LhsMapper, RhsMapper, true>(j, lhs2, rhs, quad_acc);
vecColLoop<num_acc, LhsMapper, RhsMapper, true>(cend - 1, lhs2, rhs, quad_acc);
}
disassembleAccumulators<num_acc>(quad_acc, acc);
@ -588,9 +585,7 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
if (MAX_BFLOAT16_VEC_ACC > num_acc) {
colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
}
colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
}
template<typename LhsMapper, typename RhsMapper, bool extraRows>
@ -769,7 +764,7 @@ template<Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc], Index extra_cols)
{
Index j = 0;
for(Index k = cols >> 3; k--; j += 8) {
for(; j + 8 <= cols; j += 8){
multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs, j, extra_cols);
}