Added partial linear access for LHS & Output - 30% faster for bfloat16 GEMM MMA (Power)

Authored by Chip Kerchner on 2023-03-02 19:22:43 +00:00; committed by Rasmus Munk Larsen
parent 0b396c3167
commit 2b513ca2a0
3 changed files with 249 additions and 170 deletions
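Summary: the bfloat16 MMA kernel on Power previously re-derived strided offsets for the LHS packets and for the float32 staging buffer on every accumulator pass. After this change, a full-width tile consumes several LHS packets per iteration while the RHS and result pointers advance linearly, and row remainders finish with partial vector stores instead of scalar tails. The sketch below only illustrates the partial-store idea in plain C++; `partial_store` is a hypothetical stand-in for Eigen's `pstoreu_partial`, not its implementation:

    #include <cstring>

    // Hypothetical stand-in for pstoreu_partial: store only the first n
    // lanes of a 4-float vector, so a row remainder (rows & 3) needs no
    // scalar tail loop.
    struct Vec4 { float v[4]; };

    static void partial_store(float* dst, const Vec4& x, int n)
    {
      // memcpy of n floats models a length-limited vector store
      // (stxvl-style on Power10).
      std::memcpy(dst, x.v, sizeof(float) * n);
    }

    static void store_column(float* dst, const Vec4& x, int rows_left)
    {
      if (rows_left >= 4) {
        std::memcpy(dst, x.v, sizeof(x.v)); // full store on the fast path
      } else {
        partial_store(dst, x, rows_left);   // single partial store for the tail
      }
    }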

Eigen/src/Core/arch/AltiVec/MatrixProduct.h

@@ -841,7 +841,7 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
   }
 };
 
-#ifdef __MMA__
+#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
 // General template for lhs packing, bfloat16 specialization.
 template<typename DataMapper, int StorageOrder, bool PanelMode>
 struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true>
@@ -1162,6 +1162,7 @@ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false>
   t1 = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));
   t2 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val), reinterpret_cast<Packet4ui>(block.packet[1].m_val)));
   t3 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));
 
   block.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(t0, t1));
   block.packet[1] = reinterpret_cast<Packet8us>(vec_mergel(t0, t1));
   block.packet[2] = reinterpret_cast<Packet8us>(vec_mergeh(t2, t3));
@@ -2736,7 +2737,7 @@ void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode
 }
 #endif
 
-#ifdef __MMA__
+#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
 #if EIGEN_ALTIVEC_USE_CUSTOM_PACK
 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
 struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
@@ -3270,7 +3271,7 @@ void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, Conjug
   gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
 }
 
-#if defined(__MMA__)
+#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
 struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
 {
@@ -3288,10 +3289,7 @@ void gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, Co
     Index rows, Index depth, Index cols, bfloat16 alpha,
     Index strideA, Index strideB, Index offsetA, Index offsetB)
   {
-    const Index accRows = quad_traits<bfloat16>::rows;
-    const Index accCols = quad_traits<bfloat16>::size;
-    Eigen::internal::gemmMMAbfloat16<Index, Packet, RhsPacket, DataMapper, accRows, accCols>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
+    Eigen::internal::gemmMMAbfloat16<DataMapper>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
   }
 #endif
 
 } // end namespace internal
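All three `__MMA__` guards above now also honor `EIGEN_ALTIVEC_DISABLE_MMA`, so the bfloat16 MMA path can be compiled out even when the compiler targets Power10. A sketch of the guard pattern; the default definition shown here is an assumption about how the switch is declared elsewhere in the header, not a quote from this commit:

    // Assumed default: the kill switch is off unless the user defines it.
    #if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
    #define EIGEN_ALTIVEC_DISABLE_MMA 0
    #endif

    #if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
    // MMA-specific kernels are compiled only here.
    #endif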

Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h

@@ -28,7 +28,9 @@
 #include "../../InternalHeaderCheck.h"
 
+#if !EIGEN_ALTIVEC_DISABLE_MMA
 #include "MatrixProductMMAbfloat16.h"
+#endif
 
 namespace Eigen {

Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h

@@ -29,7 +29,7 @@ EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index stri
   return loadBfloat16<zero>(blockB + strideB*i);
 }
 
-template<Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows>
+template<Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
 EIGEN_ALWAYS_INLINE void KLoop
 (
   const bfloat16* indexA,
@@ -42,250 +42,277 @@ EIGEN_ALWAYS_INLINE void KLoop
   Index extra_rows
 )
 {
-  Packet8bf lhs = loadBfloat16<zero>(indexA + k*(lhsExtraRows ? extra_rows : num_packets)); //a packet of bfloat16 has 8 elements
-  Packet8bf rhs[num_acc];
+  Packet8bf lhs[num_lhs], rhs[num_rhs];
 
-  for(Index i = 0; i < (num_acc - (rhsExtraCols ? 1 : 0)); i++){
+  for(Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++){
     rhs[i] = loadRhsBfloat16<zero>(indexB + k*4, strideB, i);
   }
   if(rhsExtraCols) {
-    rhs[num_acc - 1] = loadRhsBfloat16<zero>(indexB + k*extra_cols - offsetB, strideB, num_acc - 1);
+    rhs[num_rhs - 1] = loadRhsBfloat16<zero>(indexB + k*extra_cols - offsetB, strideB, num_rhs - 1);
   }
 
+  indexA += k*(lhsExtraRows ? extra_rows : num_packets);
+  for(Index j = 0; j < num_lhs; j++) {
+    lhs[j] = loadBfloat16<zero>(indexA + j*(zero ? 4 : 8)); //a packet of bfloat16 has 8 elements
+  }
+
   BFLOAT16_UNROLL
-  for (Index i = 0; i < num_acc; i++) {
-    __builtin_mma_xvbf16ger2pp(&(quad_acc[i]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs.m_val));
+  for(Index i = 0, k = 0; i < num_rhs; i++) {
+    BFLOAT16_UNROLL
+    for(Index j = 0; j < num_lhs; j++, k++) {
+      __builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs[j].m_val));
+    }
   }
 }
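The reworked KLoop is the heart of the change: instead of multiplying one LHS packet against num_acc RHS packets, it now loads num_lhs LHS packets and num_rhs RHS packets and feeds every rhs x lhs pair to its own MMA accumulator (num_acc is up to num_rhs * num_lhs), so each LHS packet is loaded once per iteration and read linearly. A scalar model of that accumulator tiling, with plain multiplies standing in for the `__builtin_mma_xvbf16ger2pp` rank-2 updates:

    // Scalar stand-in: accumulator k owns the product of rhs tile i and
    // lhs tile j, mirroring quad_acc[k] in KLoop.
    template<int NUM_RHS, int NUM_LHS>
    void kloop_model(const float (&lhs)[NUM_LHS], const float (&rhs)[NUM_RHS],
                     float (&acc)[NUM_RHS * NUM_LHS])
    {
      for (int i = 0, k = 0; i < NUM_RHS; i++) {
        for (int j = 0; j < NUM_LHS; j++, k++) {
          acc[k] += rhs[i] * lhs[j]; // models one xvbf16ger2pp update
        }
      }
    }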
-template <bool rhsExtraCols, bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result)
+{
+  Packet4f result_block = ploadu<Packet4f>(result);
+  return pmadd(acc, pAlpha, result_block);
+}
+
+template<bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows)
+{
+  if (lhsExtraRows) {
+    pstoreu_partial(result, result_block, extra_rows);
+  } else {
+    pstoreu(result, result_block);
+  }
+  result += rows;
+}
+
+template<bool rhsExtraCols, bool lhsExtraRows>
 EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows)
 {
   Index x = 0;
-  do{
-    Packet4f result_block = ploadu<Packet4f>(result);
-    result_block = pmadd(acc[x], pAlpha, result_block);
-    if (lhsExtraRows) {
-      pstoreu_partial(result, result_block, extra_rows);
-    } else {
-      pstoreu(result, result_block);
-    }
-    result += rows;
-  } while (++x < (rhsExtraCols ? extra_cols : 4));
+  if (rhsExtraCols) {
+    do{
+      Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
+      storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
+    } while (++x < extra_cols);
+  } else {
+    Packet4f result_block[4];
+    float *result2 = result;
+    do{
+      result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
+      result += rows;
+    } while (++x < 4);
+    x = 0;
+    do{
+      storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
+    } while (++x < 4);
+  }
 }
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc])
+{
+  BFLOAT16_UNROLL
+  for(Index k = 0; k < num_acc; k++)
+    __builtin_mma_xxsetaccz(&(quad_acc[k]));
+}
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_acc], Packet4f (&acc)[num_acc][4])
+{
+  BFLOAT16_UNROLL
+  for(Index k = 0; k < num_acc; k++)
+    __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
+}
+
+template<Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
+EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows)
+{
+  for(Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4*rows){
+    for(Index j = 0; j < num_lhs; j++, k++) {
+      storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + j*4, extra_cols, extra_rows);
+    }
+  }
+  if(rhsExtraCols) {
+    storeResults<rhsExtraCols, lhsExtraRows>(acc[num_acc - 1], rows, pAlpha, result, extra_cols, extra_rows);
+  }
+}
+
+template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows, bool multiIter = false>
+EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, const Index extra_cols, const Index extra_rows)
+{
+  constexpr Index num_lhs = multiIter ? (num_packets / 4) : 1;
+  constexpr Index num_rhs = (num_acc + num_lhs - 1) / num_lhs;
+
+  for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += (multiIter ? 0 : 8), indexB += (multiIter ? (num_rhs*strideB) : 0), result += (multiIter ? (4*rows*num_rhs) : 4)) {
+    Packet4f acc[num_acc][4];
+    __vector_quad quad_acc[num_acc];
+
+    zeroAccumulators<num_acc>(quad_acc);
+
+    Index k;
+    for(k = 0; k + 2 <= depth; k += 2){
+      KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
+    }
+    if(depth&1){
+      KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA - (multiIter ? 0 : offset_row), indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
+    }
+
+    disassembleAccumulators<num_acc>(quad_acc, acc);
+
+    outputResults<num_acc, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
+  }
+}
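The multiIter specialization of colLoopBodyIter is where the commit title's "partial linear access" shows up: when the full accumulator tile divides evenly across the LHS packets, the LHS pointer no longer advances inside the row loop (all num_packets/4 row packets are consumed each pass) while the RHS and result pointers step forward monotonically. A toy model of the two stepping modes; the names mirror the diff but this is not Eigen code:

    // Models colLoopBodyIter's loop-header pointer stepping. With multiIter
    // the LHS pointer is pinned and RHS/result advance linearly; without it,
    // the LHS advances by one packet and result by one 4-float column block.
    void stepping_model(bool multiIter, const float*& indexA, const float*& indexB,
                        float*& result, long rows, long strideB, long num_rhs)
    {
      indexA += (multiIter ? 0 : 8);
      indexB += (multiIter ? (num_rhs * strideB) : 0);
      result += (multiIter ? (4 * rows * num_rhs) : 4);
    }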
 
 #define MAX_BFLOAT16_ACC 8
 
 template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
-void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, Index extra_rows)
+void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result)
 {
-  const Index step = (num_acc * 4); //each accumulator has 4 elements
+  constexpr Index step = (num_acc * 4); //each accumulator has 4 elements
   const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
+  const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
+  constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC);
 
   do{
-    for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += 8, result += 4) {
-      Index k;
-      Packet4f acc[num_acc][4];
-      __vector_quad quad_acc[num_acc];
-
-      BFLOAT16_UNROLL
-      for(k = 0; k < num_acc; k++)
-        __builtin_mma_xxsetaccz(&(quad_acc[k]));
-
-      for(k = 0; k + 2 <= depth; k += 2){
-        KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
-      }
-      if(depth&1){
-        KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows>(indexA - offset_row, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
-      }
-
-      BFLOAT16_UNROLL
-      for(k = 0; k < num_acc; k++)
-        __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
-
-      for(k = 0; k < (num_acc - 1); k++){
-        storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows);
-      }
-      storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows);
+    if (multiIters && ((num_acc % (num_packets / 4)) == 0)) {
+      colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, true>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
+    } else {
+      colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
     }
 
-    indexA -= num_packets*2;
     indexB += strideB*num_acc;
-    result += (rows*step - num_packets);
-  } while(!rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC) && (step <= cols - (col += step)));
+    result += rows*step;
+  } while(multiIters && (step <= cols - (col += step)));
 }
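Note that both `step` and the fast-path test are now constexpr, so whether colLoopBody keeps looping over full 32-column strips is decided at compile time rather than re-evaluated per strip. A compile-time model of that dispatch; the names are illustrative, not Eigen's:

    // Model of colLoopBody's fast-path predicate: a full set of accumulators
    // with no column remainder may stay in the do/while.
    constexpr long MAX_BFLOAT16_ACC_MODEL = 8;

    template<long num_acc, bool rhsExtraCols>
    constexpr bool multi_iters_model =
        !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC_MODEL);

    static_assert(multi_iters_model<8, false>, "full tile: loop in place");
    static_assert(!multi_iters_model<8, true>, "column remainder: one pass");
    static_assert(!multi_iters_model<5, false>, "partial tile: one pass");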
 
 template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
-EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows)
+EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
 {
   if (MAX_BFLOAT16_ACC > num_acc) {
-    colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
   }
 }
 
 template<const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
-void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows)
+void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
 {
   switch ((cols - col) >> 2) {
   case 7:
-    colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
     break;
   case 6:
-    colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
    break;
   case 5:
-    colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
     break;
   case 4:
-    colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
     break;
   case 3:
-    colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
     break;
   case 2:
-    colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
     break;
   case 1:
-    colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
     break;
   default:
     if (rhsExtraCols) {
-      colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+      colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
     }
     break;
   }
 }
 
 template<const Index num_packets, bool lhsExtraRows = false>
-EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows = 0)
+EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
 {
   Index col = 0;
   if (cols >= (MAX_BFLOAT16_ACC * 4)) {
-    colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows);
+    colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
     blockB += (strideB >> 2)*col;
     result += rows*col;
   }
   if (cols & 3) {
-    colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
   } else {
-    colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows);
+    colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
   }
 }
 
-EIGEN_ALWAYS_INLINE Packet8bf convertF16toF32(const float *res)
+template<bool full = true>
+EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
 {
   Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
-  Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4)));
+  Packet16uc fp16_1 = (full) ? __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4))) : fp16_0;
   return vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
 }
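The old name convertF16toF32 was a misnomer: the routine packs float32 back down to bfloat16, so it is renamed convertF32toBF16, and the new `full` flag lets a caller convert a half-width 4-element block without reading past the end of the buffer. For reference, the per-lane conversion `__builtin_vsx_xvcvspbf16` performs is, up to rounding, a 16-bit narrowing; the hardware rounds to nearest-even, while this scalar model simply truncates:

    #include <cstdint>
    #include <cstring>

    // bfloat16 is the top 16 bits of an IEEE-754 float32.
    static std::uint16_t f32_to_bf16_truncate(float f)
    {
      std::uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      return static_cast<std::uint16_t>(bits >> 16); // drop low mantissa bits
    }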
 
-template<typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
-void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
-{
-  if(rows == 0 || cols == 0 || depth == 0) return;
-  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
-  if (falpha == float(0)) return;
-  const Packet4f pAlpha = pset1<Packet4f>(falpha);
-  ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
-
-  typedef typename DataMapper::LinearMapper LinearMapper;
-  Packet8us z = pset1<Packet8us>(0);
-  for(Index j = 0; j < cols; j++){
-    const LinearMapper res2 = res.getLinearMapper(0, j);
-    float *result2 = result + j*rows;
-    Index i = 0;
-    for(; i + 32 <= rows; i+=32){
-      Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
-      Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i + 8).m_val;
-      Packet8us r32_2 = res2.template loadPacket<Packet8bf>(i + 16).m_val;
-      Packet8us r32_3 = res2.template loadPacket<Packet8bf>(i + 24).m_val;
-      pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
-      pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
-      pstore(result2 + i + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1)));
-      pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1)));
-      pstore(result2 + i + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_2)));
-      pstore(result2 + i + 20, reinterpret_cast<Packet4f>(vec_mergel(z, r32_2)));
-      pstore(result2 + i + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_3)));
-      pstore(result2 + i + 28, reinterpret_cast<Packet4f>(vec_mergel(z, r32_3)));
-    }
-    for(; i + 16 <= rows; i+=16){
-      Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
-      Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i + 8).m_val;
-      pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
-      pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
-      pstore(result2 + i + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1)));
-      pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1)));
-    }
-    for(; i + 8 <= rows; i+=8){
-      Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
-      pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
-      pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
-    }
-    for(; i + 4 <= rows; i+=4){
-      Packet8us r32_0 = res2.template loadPacketPartial<Packet8bf>(i + 0, 4).m_val;
-      pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
-    }
-    for(; i < rows; i++){
-      result2[i] = res2(i);
-    }
-  }
+template<int N>
+EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf,(N+4)/8>& block)
+{
+  Packet8us z = pset1<Packet8us>(0);
+  pstore(to + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[0].m_val)));
+  if (N >= 8) {
+    pstore(to + 4, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[0].m_val)));
+  }
+  if (N >= 16) {
+    pstore(to + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[1].m_val)));
+    pstore(to + 12, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[1].m_val)));
+  }
+  if (N >= 32) {
+    pstore(to + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[2].m_val)));
+    pstore(to + 20, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[2].m_val)));
+    pstore(to + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[3].m_val)));
+    pstore(to + 28, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[3].m_val)));
+  }
+}
+
+template<const Index size, typename DataMapper>
+EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
+{
+  for(; i + size <= rows; i += size){
+    PacketBlock<Packet8bf,(size+4)/8> r32;
+    r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
+    if (size >= 16) {
+      r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
+    }
+    if (size >= 32) {
+      r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
+      r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
+    }
+    storeConvertBlockBF16<size>(result + i, r32);
+  }
+}
+
+template<typename DataMapper>
+EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src)
+{
+  typedef typename DataMapper::LinearMapper LinearMapper;
+  for(Index j = 0; j < cols; j++, result += rows){
+    const LinearMapper src2 = src.getLinearMapper(0, j);
+    Index i = 0;
+    convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
+    convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
+    convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
+    convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
+    for(; i < rows; i++){
+      result[i] = Eigen::bfloat16_impl::bfloat16_to_float(src2(i));
+    }
+  }
+}
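The BF16-to-F32 staging copy is now factored into convertBF16toF32, instantiated at widths 32/16/8/4 so the remainder cascade is fully unrolled, with a scalar bfloat16_to_float tail for the last 1-3 rows. The trick inside storeConvertBlockBF16 is that vec_mergeh/vec_mergel with a zero vector interleave a zero halfword next to each bfloat16 value, placing the payload in the high half of every 32-bit lane; that is the vector form of the scalar widening below (lane ordering is endian-dependent, which the scalar model sidesteps):

    #include <cstdint>

    // Scalar equivalent of one lane of vec_mergeh(z, x) with z == 0: the
    // bfloat16 payload becomes the high 16 bits of a float32 bit pattern,
    // an exact (lossless) widening.
    static std::uint32_t widen_bf16_bits(std::uint16_t bf16_bits)
    {
      return static_cast<std::uint32_t>(bf16_bits) << 16;
    }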
 
-  Index row = 0;
-  Index col;
-
-  if( strideA == -1 ) strideA = depth;
-  if( strideB == -1 ) strideB = depth;
-  //Packing is done in blocks.
-  //There's 4 possible sizes of blocks
-  //Blocks of 8 columns with 16 elements (8x16)
-  //Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
-  //Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
-  //Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows
-
-  //Loop for LHS standard block (8x16)
-  const Index standard_block_size = 16;
-  const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks
-  Index bigSuffix = (2*8) * (strideA-offsetA);
-  const bfloat16* indexA = blockA;
-  const bfloat16* indexB = blockB + 4*offsetB;
-  Index block_index;
-  strideB *= 4;
-  offsetB *= 3;
-  for(block_index = 0; block_index < standard_blocks_quantity; block_index++){
-    indexA += 2*8*offsetA;
-    colLoops<16>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
-    row += 16;
-    indexA += bigSuffix;
-  }
-  //LHS (8x8) block
-  if(rows & 8){
-    indexA += 1*8*offsetA;
-    colLoops<8>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
-    row += 8;
-    indexA += (bigSuffix >> 1);
-  }
-  //LHS (8x4) block
-  if(rows & 4){
-    indexA += 1*4*offsetA;
-    colLoops<4>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
-    row += 4;
-    indexA += (bigSuffix >> 2);
-  }
-  //extra rows
-  Index extra_rows = rows & 3;
-  if(extra_rows){
-    //This index is the beginning of remaining block.
-    colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row, extra_rows);
-  }
-
-  //Convert back to bfloat16
+template<typename DataMapper>
+EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper& res)
+{
+  typedef typename DataMapper::LinearMapper LinearMapper;
+  Index col, row;
+
   for(col = 0; col + 4 <= cols; col += 4){
     const DataMapper res2 = res.getSubMapper(0, col);
     for(row = 0; row + 8 <= rows; row += 8){
       //get and save block
       PacketBlock<Packet8bf,4> block;
       for(Index j = 0; j < 4; j++){
-        block.packet[j].m_val = convertF16toF32(result + (col + j)*rows + row);
+        block.packet[j].m_val = convertF32toBF16(result + (col + j)*rows + row);
       }
       res2.template storePacketBlock<Packet8bf,4>(row, 0, block);
@@ -303,18 +330,70 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
   while(col < cols){
     const LinearMapper res2 = res.getLinearMapper(0, col);
     float *result2 = result + col*rows;
-    Index r = 0;
-    for(; r + 8 <= rows; r += 8){
-      Packet8bf fp16 = convertF16toF32(result2 + r);
-      res2.template storePacket<Packet8bf>(r, fp16);
+    for(row = 0; row + 8 <= rows; row += 8){
+      Packet8bf fp16 = convertF32toBF16(result2 + row);
+      res2.template storePacket<Packet8bf>(row, fp16);
     }
-    for(; r< rows; r++){
-      res2(r) = Eigen::bfloat16(result2[r]);
+    for(; row < rows; row++){
+      res2(row) = Eigen::bfloat16(result2[row]);
     }
     col++;
   }
 }
+
+template<Index size>
+EIGEN_ALWAYS_INLINE void calcColLoops(const bfloat16*& indexA, Index& row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexB, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result)
+{
+  if ((size == 16) || (rows & size)) {
+    indexA += size*offsetA;
+    colLoops<size>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
+    row += size;
+    indexA += bigSuffix*size/16;
+  }
+}
+
+template<typename DataMapper>
+void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
+  if (falpha == float(0)) return;
+  const Packet4f pAlpha = pset1<Packet4f>(falpha);
+  ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
+
+  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
+
+  Index row = 0;
+
+  if( strideA == -1 ) strideA = depth;
+  if( strideB == -1 ) strideB = depth;
+  //Packing is done in blocks.
+  //There's 4 possible sizes of blocks
+  //Blocks of 8 columns with 16 elements (8x16)
+  //Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
+  //Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
+  //Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows
+
+  //Loop for LHS standard block (8x16)
+  Index bigSuffix = (2*8) * (strideA-offsetA);
+  indexB += 4*offsetB;
+  strideB *= 4;
+  offsetB *= 3;
+  while(row + 16 <= rows){
+    calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
+  }
+  //LHS (8x8) block
+  calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
+  //LHS (8x4) block
+  calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
+  //extra rows
+  if(rows & 3){
+    //This index is the beginning of remaining block.
+    colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
+  }
+
+  //Convert back to bfloat16
+  convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
+}
 
 }
 }
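None of this is user-facing: the new gemmMMAbfloat16<DataMapper> entry point is reached through gebp_kernel<bfloat16, bfloat16, ...> whenever an MMA-enabled Power build multiplies bfloat16 matrices. A minimal end-to-end check; it assumes an Eigen build where this path is active (Power10, __MMA__, EIGEN_ALTIVEC_DISABLE_MMA unset), though the caller cannot observe which kernel ran:

    #include <Eigen/Dense>

    int main()
    {
      using MatBF16 = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, Eigen::Dynamic>;
      MatBF16 A = MatBF16::Ones(64, 64);
      MatBF16 B = MatBF16::Ones(64, 64);
      // Routed through gebp_kernel<bfloat16,...> -> gemmMMAbfloat16 on MMA builds;
      // each entry of C is exactly 64, which bfloat16 represents without rounding.
      MatBF16 C = A * B;
      return (C(0, 0) == Eigen::bfloat16(64.f)) ? 0 : 1;
    }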