Added partial linear access for LHS & Output - 30% faster for bfloat16 GEMM MMA (Power)

Authored by Chip Kerchner on 2023-03-02 19:22:43 +00:00; committed by Rasmus Munk Larsen
parent 0b396c3167
commit 2b513ca2a0
3 changed files with 249 additions and 170 deletions

View File

@@ -841,7 +841,7 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
}
};
#ifdef __MMA__
#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
// General template for lhs packing, bfloat16 specialization.
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true>
@@ -1162,6 +1162,7 @@ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false>
t1 = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));
t2 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val), reinterpret_cast<Packet4ui>(block.packet[1].m_val)));
t3 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));
block.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(t0, t1));
block.packet[1] = reinterpret_cast<Packet8us>(vec_mergel(t0, t1));
block.packet[2] = reinterpret_cast<Packet8us>(vec_mergeh(t2, t3));
@@ -2736,7 +2737,7 @@ void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode
}
#endif
#ifdef __MMA__
#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
@@ -3270,7 +3271,7 @@ void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, Conjug
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
#if defined(__MMA__)
#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
@@ -3288,10 +3289,7 @@ void gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, Co
Index rows, Index depth, Index cols, bfloat16 alpha,
Index strideA, Index strideB, Index offsetA, Index offsetB)
{
const Index accRows = quad_traits<bfloat16>::rows;
const Index accCols = quad_traits<bfloat16>::size;
Eigen::internal::gemmMMAbfloat16<Index, Packet, RhsPacket, DataMapper, accRows, accCols>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
Eigen::internal::gemmMMAbfloat16<DataMapper>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
#endif
} // end namespace internal
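Note: all three MMA guards in this file now also honour the EIGEN_ALTIVEC_DISABLE_MMA switch instead of keying on __MMA__ alone, and the bfloat16 gebp_kernel passes only the DataMapper to gemmMMAbfloat16 since the kernel now derives its own block sizes. In isolation, the guard pattern looks roughly like the following sketch (it assumes the switch defaults to 0 when the user has not set it; this is not the Eigen header itself):

// Minimal sketch of the compile-time MMA guard pattern (assumption: the
// EIGEN_ALTIVEC_DISABLE_MMA switch defaults to 0 when unset).
#ifndef EIGEN_ALTIVEC_DISABLE_MMA
#define EIGEN_ALTIVEC_DISABLE_MMA 0
#endif

#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
// MMA bfloat16 packing and gebp kernels are only compiled here, so builds
// without MMA support, or with the switch set to 1, keep the generic paths.
#endif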

View File

@@ -28,7 +28,9 @@
#include "../../InternalHeaderCheck.h"
#if !EIGEN_ALTIVEC_DISABLE_MMA
#include "MatrixProductMMAbfloat16.h"
#endif
namespace Eigen {

View File

@@ -29,7 +29,7 @@ EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index stri
return loadBfloat16<zero>(blockB + strideB*i);
}
template<Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows>
template<Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void KLoop
(
const bfloat16* indexA,
@@ -42,250 +42,277 @@ EIGEN_ALWAYS_INLINE void KLoop
Index extra_rows
)
{
Packet8bf lhs = loadBfloat16<zero>(indexA + k*(lhsExtraRows ? extra_rows : num_packets)); //a packet of bfloat16 has 8 elements
Packet8bf rhs[num_acc];
Packet8bf lhs[num_lhs], rhs[num_rhs];
for(Index i = 0; i < (num_acc - (rhsExtraCols ? 1 : 0)); i++){
for(Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++){
rhs[i] = loadRhsBfloat16<zero>(indexB + k*4, strideB, i);
}
if(rhsExtraCols) {
rhs[num_acc - 1] = loadRhsBfloat16<zero>(indexB + k*extra_cols - offsetB, strideB, num_acc - 1);
rhs[num_rhs - 1] = loadRhsBfloat16<zero>(indexB + k*extra_cols - offsetB, strideB, num_rhs - 1);
}
indexA += k*(lhsExtraRows ? extra_rows : num_packets);
for(Index j = 0; j < num_lhs; j++) {
lhs[j] = loadBfloat16<zero>(indexA + j*(zero ? 4 : 8)); //a packet of bfloat16 has 8 elements
}
BFLOAT16_UNROLL
for (Index i = 0; i < num_acc; i++) {
__builtin_mma_xvbf16ger2pp(&(quad_acc[i]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs.m_val));
for(Index i = 0, k = 0; i < num_rhs; i++) {
BFLOAT16_UNROLL
for(Index j = 0; j < num_lhs; j++, k++) {
__builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs[j].m_val));
}
}
}
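Each __builtin_mma_xvbf16ger2pp call performs a rank-2 (two depth steps) outer-product update of a 4x4 float accumulator from two bf16 vectors; the reworked KLoop loads num_lhs LHS packets and num_rhs RHS packets and gives every rhs/lhs pair its own accumulator. A scalar reference of a single such update follows, as a sketch only: the helper names and the exact element-to-row/column mapping are assumptions, not the ISA definition.

// Scalar reference for one xvbf16ger2pp-style accumulator update (sketch).
#include <cstdint>
#include <cstring>

// A bfloat16 value is the high half of an IEEE-754 binary32.
static inline float bf16ToF32(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// acc is a 4x4 float tile; rhs and lhs each carry 4 elements for 2 depth steps,
// assumed here to be laid out as [k0:e0..e3 | k1:e0..e3] in one 8-entry vector.
static inline void bf16Ger2ppReference(float acc[4][4],
                                       const uint16_t rhs[8],
                                       const uint16_t lhs[8]) {
  for (int k = 0; k < 2; ++k)       // two depth steps folded into one update
    for (int i = 0; i < 4; ++i)     // rows taken from the rhs packet
      for (int j = 0; j < 4; ++j)   // columns taken from the lhs packet
        acc[i][j] += bf16ToF32(rhs[k*4 + i]) * bf16ToF32(lhs[k*4 + j]);
}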
template <bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result)
{
Packet4f result_block = ploadu<Packet4f>(result);
return pmadd(acc, pAlpha, result_block);
}
template<bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows)
{
if (lhsExtraRows) {
pstoreu_partial(result, result_block, extra_rows);
} else {
pstoreu(result, result_block);
}
result += rows;
}
template<bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows)
{
Index x = 0;
do{
Packet4f result_block = ploadu<Packet4f>(result);
result_block = pmadd(acc[x], pAlpha, result_block);
if (lhsExtraRows) {
pstoreu_partial(result, result_block, extra_rows);
} else {
pstoreu(result, result_block);
if (rhsExtraCols) {
do{
Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
} while (++x < extra_cols);
} else {
Packet4f result_block[4];
float *result2 = result;
do{
result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
result += rows;
} while (++x < 4);
x = 0;
do{
storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
} while (++x < 4);
}
}
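storeResults is now built from two helpers, loadAndMultiplyF32 and storeF32: the full four-column path stages all loads and multiply-adds before any store, while the leftover-row path keeps using pstoreu_partial. A plain scalar sketch of that two-phase pattern (the function name, column-major layout, and the absence of partial stores are assumptions):

// Two-phase scale-and-accumulate of a 4x4 float tile into a column-major result.
void scaleAccumulateTile(float* result, long rows, const float acc[4][4], float alpha) {
  float staged[4][4];
  for (int x = 0; x < 4; ++x)       // phase 1: load existing values and fuse the multiply-add
    for (int i = 0; i < 4; ++i)
      staged[x][i] = result[x*rows + i] + acc[x][i] * alpha;
  for (int x = 0; x < 4; ++x)       // phase 2: write all four columns back
    for (int i = 0; i < 4; ++i)
      result[x*rows + i] = staged[x][i];
}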
template<Index num_acc>
EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc])
{
BFLOAT16_UNROLL
for(Index k = 0; k < num_acc; k++)
__builtin_mma_xxsetaccz(&(quad_acc[k]));
}
template<Index num_acc>
EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_acc], Packet4f (&acc)[num_acc][4])
{
BFLOAT16_UNROLL
for(Index k = 0; k < num_acc; k++)
__builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
}
template<Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows)
{
for(Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4*rows){
for(Index j = 0; j < num_lhs; j++, k++) {
storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + j*4, extra_cols, extra_rows);
}
result += rows;
} while (++x < (rhsExtraCols ? extra_cols : 4));
}
if(rhsExtraCols) {
storeResults<rhsExtraCols, lhsExtraRows>(acc[num_acc - 1], rows, pAlpha, result, extra_cols, extra_rows);
}
}
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows, bool multiIter = false>
EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, const Index extra_cols, const Index extra_rows)
{
constexpr Index num_lhs = multiIter ? (num_packets / 4) : 1;
constexpr Index num_rhs = (num_acc + num_lhs - 1) / num_lhs;
for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += (multiIter ? 0 : 8), indexB += (multiIter ? (num_rhs*strideB) : 0), result += (multiIter ? (4*rows*num_rhs) : 4)) {
Packet4f acc[num_acc][4];
__vector_quad quad_acc[num_acc];
zeroAccumulators<num_acc>(quad_acc);
Index k;
for(k = 0; k + 2 <= depth; k += 2){
KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
}
if(depth&1){
KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA - (multiIter ? 0 : offset_row), indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
}
disassembleAccumulators<num_acc>(quad_acc, acc);
outputResults<num_acc, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
}
}
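In the multiIter case, colLoopBodyIter spreads the accumulators over num_lhs = num_packets/4 four-row LHS strips and num_rhs RHS column groups, so one pass over the depth loop covers the whole LHS block. A small stand-alone check of that split, assuming MAX_BFLOAT16_ACC stays at 8 (the #define just below):

// Stand-alone check of the accumulator split used in the multiIter case.
#include <cstdio>

int main() {
  const int max_acc = 8;                              // MAX_BFLOAT16_ACC
  const int block_heights[] = {16, 8, 4};             // num_packets values used by the driver
  for (int num_packets : block_heights) {
    int num_lhs = num_packets / 4;                    // 4-row LHS strips per block
    int num_rhs = (max_acc + num_lhs - 1) / num_lhs;  // RHS column groups per pass
    std::printf("num_packets=%2d -> num_lhs=%d, num_rhs=%d (%d accumulators)\n",
                num_packets, num_lhs, num_rhs, num_lhs * num_rhs);
  }
  return 0;
}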
#define MAX_BFLOAT16_ACC 8
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, Index extra_rows)
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result)
{
const Index step = (num_acc * 4); //each accumulator has 4 elements
constexpr Index step = (num_acc * 4); //each accumulator has 4 elements
const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC);
do{
for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += 8, result += 4) {
Index k;
Packet4f acc[num_acc][4];
__vector_quad quad_acc[num_acc];
BFLOAT16_UNROLL
for(k = 0; k < num_acc; k++)
__builtin_mma_xxsetaccz(&(quad_acc[k]));
for(k = 0; k + 2 <= depth; k += 2){
KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
}
if(depth&1){
KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows>(indexA - offset_row, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
}
BFLOAT16_UNROLL
for(k = 0; k < num_acc; k++)
__builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
for(k = 0; k < (num_acc - 1); k++){
storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows);
}
storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows);
if (multiIters && ((num_acc % (num_packets / 4)) == 0)) {
colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, true>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
} else {
colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
}
indexA -= num_packets*2;
indexB += strideB*num_acc;
result += (rows*step - num_packets);
} while(!rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC) && (step <= cols - (col += step)));
result += rows*step;
} while(multiIters && (step <= cols - (col += step)));
}
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows)
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
if (MAX_BFLOAT16_ACC > num_acc) {
colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
}
}
template<const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows)
void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
switch ((cols - col) >> 2) {
case 7:
colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 6:
colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 5:
colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 4:
colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 3:
colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 2:
colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 1:
colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
default:
if (rhsExtraCols) {
colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
}
break;
}
}
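colLoopBodyExtra converts the runtime count of leftover four-column groups into a compile-time accumulator count through the switch above, so each case instantiates a kernel with a fixed template parameter. The pattern in isolation, with placeholder names, looks like this:

// Runtime-to-compile-time dispatch pattern (placeholder kernel names).
#include <cstdio>

template<int NumAcc> void kernelWithAccumulators() {
  std::printf("instantiated with %d accumulators\n", NumAcc);
}

void dispatchRemainingQuads(int remaining_quads) {
  switch (remaining_quads) {
    case 7: kernelWithAccumulators<7>(); break;
    case 6: kernelWithAccumulators<6>(); break;
    case 5: kernelWithAccumulators<5>(); break;
    case 4: kernelWithAccumulators<4>(); break;
    case 3: kernelWithAccumulators<3>(); break;
    case 2: kernelWithAccumulators<2>(); break;
    case 1: kernelWithAccumulators<1>(); break;
    default: break;   // nothing left, or handled by the extra-columns path
  }
}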
template<const Index num_packets, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows = 0)
EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
Index col = 0;
if (cols >= (MAX_BFLOAT16_ACC * 4)) {
colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows);
colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
blockB += (strideB >> 2)*col;
result += rows*col;
}
if (cols & 3) {
colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
} else {
colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows);
colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
}
}
EIGEN_ALWAYS_INLINE Packet8bf convertF16toF32(const float *res)
template<bool full = true>
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
{
Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4)));
Packet16uc fp16_1 = (full) ? __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4))) : fp16_0;
return vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
}
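convertF32toBF16 converts with the xvcvspbf16 builtin and packs two four-float vectors into one eight-entry bf16 packet; the new `full` flag skips the second load and duplicates the first conversion when only four results are needed. For reference, a portable scalar equivalent of the float-to-bfloat16 rounding (round to nearest, ties to even; NaN handling omitted, so a sketch rather than a drop-in replacement):

// Portable scalar reference for float -> bfloat16 with round-to-nearest-even.
#include <cstdint>
#include <cstring>

static inline uint16_t f32ToBf16Bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  uint32_t rounding = 0x7FFFu + ((u >> 16) & 1u);  // ties round to the even result
  return static_cast<uint16_t>((u + rounding) >> 16);
}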
template<typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
template<int N>
EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf,(N+4)/8>& block)
{
if(rows == 0 || cols == 0 || depth == 0) return;
float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
if (falpha == float(0)) return;
const Packet4f pAlpha = pset1<Packet4f>(falpha);
ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
typedef typename DataMapper::LinearMapper LinearMapper;
Packet8us z = pset1<Packet8us>(0);
for(Index j = 0; j < cols; j++){
const LinearMapper res2 = res.getLinearMapper(0, j);
float *result2 = result + j*rows;
pstore(to + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[0].m_val)));
if (N >= 8) {
pstore(to + 4, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[0].m_val)));
}
if (N >= 16) {
pstore(to + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[1].m_val)));
pstore(to + 12, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[1].m_val)));
}
if (N >= 32) {
pstore(to + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[2].m_val)));
pstore(to + 20, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[2].m_val)));
pstore(to + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[3].m_val)));
pstore(to + 28, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[3].m_val)));
}
}
template<const Index size, typename DataMapper>
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
{
for(; i + size <= rows; i += size){
PacketBlock<Packet8bf,(size+4)/8> r32;
r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
if (size >= 16) {
r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
}
if (size >= 32) {
r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
}
storeConvertBlockBF16<size>(result + i, r32);
}
}
template<typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src)
{
typedef typename DataMapper::LinearMapper LinearMapper;
for(Index j = 0; j < cols; j++, result += rows){
const LinearMapper src2 = src.getLinearMapper(0, j);
Index i = 0;
for(; i + 32 <= rows; i+=32){
Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i + 8).m_val;
Packet8us r32_2 = res2.template loadPacket<Packet8bf>(i + 16).m_val;
Packet8us r32_3 = res2.template loadPacket<Packet8bf>(i + 24).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
pstore(result2 + i + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1)));
pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1)));
pstore(result2 + i + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_2)));
pstore(result2 + i + 20, reinterpret_cast<Packet4f>(vec_mergel(z, r32_2)));
pstore(result2 + i + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_3)));
pstore(result2 + i + 28, reinterpret_cast<Packet4f>(vec_mergel(z, r32_3)));
}
for(; i + 16 <= rows; i+=16){
Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i + 8).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
pstore(result2 + i + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1)));
pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1)));
}
for(; i + 8 <= rows; i+=8){
Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
}
for(; i + 4 <= rows; i+=4){
Packet8us r32_0 = res2.template loadPacketPartial<Packet8bf>(i + 0, 4).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
}
convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
for(; i < rows; i++){
result2[i] = res2(i);
result[i] = Eigen::bfloat16_impl::bfloat16_to_float(src2(i));
}
}
}
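convertArrayBF16toF32 and storeConvertBlockBF16 widen the existing bf16 output into the float scratch buffer by merging each bf16 halfword with a zero halfword: a bfloat16 value is exactly the upper 16 bits of the corresponding IEEE-754 float, so widening only zero-fills the low half of each lane. A scalar sketch of that per-lane widening:

// Scalar sketch of the zero-merge widening used by the conversion above.
#include <cstdint>
#include <cstring>

void widenBf16Block(const uint16_t in[8], float out[8]) {
  for (int i = 0; i < 8; ++i) {
    uint32_t bits = static_cast<uint32_t>(in[i]) << 16;  // {bf16 bits, 0x0000}
    std::memcpy(&out[i], &bits, sizeof(float));
  }
}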
Index row = 0;
Index col;
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
//Packing is done in blocks.
//There's 4 possible sizes of blocks
//Blocks of 8 columns with 16 elements (8x16)
//Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
//Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
//Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows
//Loop for LHS standard block (8x16)
const Index standard_block_size = 16;
const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks
Index bigSuffix = (2*8) * (strideA-offsetA);
const bfloat16* indexA = blockA;
const bfloat16* indexB = blockB + 4*offsetB;
Index block_index;
strideB *= 4;
offsetB *= 3;
for(block_index = 0; block_index < standard_blocks_quantity; block_index++){
indexA += 2*8*offsetA;
colLoops<16>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
row += 16;
indexA += bigSuffix;
}
//LHS (8x8) block
if(rows & 8){
indexA += 1*8*offsetA;
colLoops<8>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
row += 8;
indexA += (bigSuffix >> 1);
}
//LHS (8x4) block
if(rows & 4){
indexA += 1*4*offsetA;
colLoops<4>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
row += 4;
indexA += (bigSuffix >> 2);
}
//extra rows
Index extra_rows = rows & 3;
if(extra_rows){
//This index is the beginning of remaining block.
colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row, extra_rows);
}
//Convert back to bfloat16
template<typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper& res)
{
typedef typename DataMapper::LinearMapper LinearMapper;
Index col, row;
for(col = 0; col + 4 <= cols; col += 4){
const DataMapper res2 = res.getSubMapper(0, col);
for(row = 0; row + 8 <= rows; row += 8){
//get and save block
PacketBlock<Packet8bf,4> block;
for(Index j = 0; j < 4; j++){
block.packet[j].m_val = convertF16toF32(result + (col + j)*rows + row);
block.packet[j].m_val = convertF32toBF16(result + (col + j)*rows + row);
}
res2.template storePacketBlock<Packet8bf,4>(row, 0, block);
@@ -303,18 +330,70 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
while(col < cols){
const LinearMapper res2 = res.getLinearMapper(0, col);
float *result2 = result + col*rows;
Index r = 0;
for(; r + 8 <= rows; r += 8){
Packet8bf fp16 = convertF16toF32(result2 + r);
res2.template storePacket<Packet8bf>(r, fp16);
for(row = 0; row + 8 <= rows; row += 8){
Packet8bf fp16 = convertF32toBF16(result2 + row);
res2.template storePacket<Packet8bf>(row, fp16);
}
for(; r< rows; r++){
res2(r) = Eigen::bfloat16(result2[r]);
for(; row < rows; row++){
res2(row) = Eigen::bfloat16(result2[row]);
}
col++;
}
}
template<Index size>
EIGEN_ALWAYS_INLINE void calcColLoops(const bfloat16*& indexA, Index& row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexB, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result)
{
if ((size == 16) || (rows & size)) {
indexA += size*offsetA;
colLoops<size>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
row += size;
indexA += bigSuffix*size/16;
}
}
template<typename DataMapper>
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
if (falpha == float(0)) return;
const Packet4f pAlpha = pset1<Packet4f>(falpha);
ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
Index row = 0;
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
//Packing is done in blocks.
//There's 4 possible sizes of blocks
//Blocks of 8 columns with 16 elements (8x16)
//Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
//Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
//Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows
//Loop for LHS standard block (8x16)
Index bigSuffix = (2*8) * (strideA-offsetA);
indexB += 4*offsetB;
strideB *= 4;
offsetB *= 3;
while(row + 16 <= rows){
calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
}
//LHS (8x8) block
calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
//LHS (8x4) block
calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
//extra rows
if(rows & 3){
//This index is the beginning of remaining block.
colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
}
//Convert back to bfloat16
convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
}
}
}
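Taken together, gemmMMAbfloat16 now stages everything through a float buffer: widen the existing bf16 result, run the MMA column loops to accumulate alpha*A*B in float, then round back to bf16. A stand-alone scalar reference of that staging strategy (unpacked column-major operands and hypothetical helper names, so purely illustrative; the real kernel works on packed blocks):

// Scalar reference of the f32 staging strategy for a bf16 GEMM update.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

static inline float bf16ToF32(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f; std::memcpy(&f, &u, sizeof(f)); return f;
}
static inline uint16_t f32ToBf16(float f) {
  uint32_t u; std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>((u + 0x7FFFu + ((u >> 16) & 1u)) >> 16);
}

// C (rows x cols), A (rows x depth), B (depth x cols), all column-major bf16 bits.
void gemmBf16Reference(uint16_t* C, const uint16_t* A, const uint16_t* B,
                       long rows, long depth, long cols, float alpha) {
  if (rows == 0 || cols == 0 || depth == 0 || alpha == 0.0f) return;
  std::vector<float> result(static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols));
  for (long j = 0; j < cols; ++j)          // stage: widen the existing bf16 output to f32
    for (long i = 0; i < rows; ++i)
      result[j*rows + i] = bf16ToF32(C[j*rows + i]);
  for (long j = 0; j < cols; ++j)          // accumulate alpha * A * B in f32
    for (long k = 0; k < depth; ++k)
      for (long i = 0; i < rows; ++i)
        result[j*rows + i] += alpha * bf16ToF32(A[k*rows + i]) * bf16ToF32(B[j*depth + k]);
  for (long j = 0; j < cols; ++j)          // round the staged result back to bf16
    for (long i = 0; i < rows; ++i)
      C[j*rows + i] = f32ToBf16(result[j*rows + i]);
}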