Mirror of https://gitlab.com/libeigen/eigen.git, synced 2025-06-21 20:09:06 +08:00
Added partial linear access for LHS & Output - 30% faster for bfloat16 GEMM MMA (Power)
This commit is contained in:
parent 0b396c3167
commit 2b513ca2a0
@@ -841,7 +841,7 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
}
};

#ifdef __MMA__
#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
// General template for lhs packing, bfloat16 specialization.
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true>
@@ -1162,6 +1162,7 @@ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false>
t1 = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));
t2 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val), reinterpret_cast<Packet4ui>(block.packet[1].m_val)));
t3 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));

block.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(t0, t1));
block.packet[1] = reinterpret_cast<Packet8us>(vec_mergel(t0, t1));
block.packet[2] = reinterpret_cast<Packet8us>(vec_mergeh(t2, t3));
@@ -2736,7 +2737,7 @@ void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode
}
#endif

#ifdef __MMA__
#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
@@ -3270,7 +3271,7 @@ void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, Conjug
gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}

#if defined(__MMA__)
#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__MMA__)
template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>
{
@@ -3288,10 +3289,7 @@ void gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, Co
Index rows, Index depth, Index cols, bfloat16 alpha,
Index strideA, Index strideB, Index offsetA, Index offsetB)
{
const Index accRows = quad_traits<bfloat16>::rows;
const Index accCols = quad_traits<bfloat16>::size;

Eigen::internal::gemmMMAbfloat16<Index, Packet, RhsPacket, DataMapper, accRows, accCols>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
Eigen::internal::gemmMMAbfloat16<DataMapper>(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
#endif
} // end namespace internal
@@ -28,7 +28,9 @@

#include "../../InternalHeaderCheck.h"

#if !EIGEN_ALTIVEC_DISABLE_MMA
#include "MatrixProductMMAbfloat16.h"
#endif

namespace Eigen {

@@ -29,7 +29,7 @@ EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index stri
return loadBfloat16<zero>(blockB + strideB*i);
}

template<Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows>
template<Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void KLoop
(
const bfloat16* indexA,
@@ -42,250 +42,277 @@ EIGEN_ALWAYS_INLINE void KLoop
Index extra_rows
)
{
Packet8bf lhs = loadBfloat16<zero>(indexA + k*(lhsExtraRows ? extra_rows : num_packets)); //a packet of bfloat16 has 8 elements
Packet8bf rhs[num_acc];
Packet8bf lhs[num_lhs], rhs[num_rhs];

for(Index i = 0; i < (num_acc - (rhsExtraCols ? 1 : 0)); i++){
for(Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++){
rhs[i] = loadRhsBfloat16<zero>(indexB + k*4, strideB, i);
}
if(rhsExtraCols) {
rhs[num_acc - 1] = loadRhsBfloat16<zero>(indexB + k*extra_cols - offsetB, strideB, num_acc - 1);
rhs[num_rhs - 1] = loadRhsBfloat16<zero>(indexB + k*extra_cols - offsetB, strideB, num_rhs - 1);
}

indexA += k*(lhsExtraRows ? extra_rows : num_packets);
for(Index j = 0; j < num_lhs; j++) {
lhs[j] = loadBfloat16<zero>(indexA + j*(zero ? 4 : 8)); //a packet of bfloat16 has 8 elements
}

BFLOAT16_UNROLL
for (Index i = 0; i < num_acc; i++) {
__builtin_mma_xvbf16ger2pp(&(quad_acc[i]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs.m_val));
for(Index i = 0, k = 0; i < num_rhs; i++) {
BFLOAT16_UNROLL
for(Index j = 0; j < num_lhs; j++, k++) {
__builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs[j].m_val));
}
}
}
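For context, the __builtin_mma_* intrinsics used in KLoop follow a zero / accumulate / disassemble pattern, and each xvbf16ger2pp call folds two k-steps (a rank-2 bfloat16 update) into a 4x4 float accumulator, which is why the k loop advances by 2 and an odd depth is handled separately. Below is a minimal standalone sketch of that pattern, not Eigen code; it assumes GCC or Clang targeting POWER10 with MMA enabled (for example -mcpu=power10 -mmma), and all names in it are illustrative.

#include <altivec.h>

typedef __vector unsigned char vec_uc;
typedef __vector float vec_f;

// Accumulate depth_pairs rank-2 bf16 outer-product updates into one 4x4 float tile.
static void bf16_tile_accumulate(const vec_uc* lhs_bf16, const vec_uc* rhs_bf16,
                                 vec_f (&tile)[4], int depth_pairs)
{
  __vector_quad acc;
  __builtin_mma_xxsetaccz(&acc);                  // zero the 4x4 accumulator
  for (int k = 0; k < depth_pairs; k++) {
    // Each call multiplies a pair of bf16 vectors (two k-steps) and adds into acc.
    __builtin_mma_xvbf16ger2pp(&acc, rhs_bf16[k], lhs_bf16[k]);
  }
  __builtin_mma_disassemble_acc(&tile, &acc);     // spill acc into four float vectors
}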
template <bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result)
{
Packet4f result_block = ploadu<Packet4f>(result);
return pmadd(acc, pAlpha, result_block);
}

template<bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows)
{
if (lhsExtraRows) {
pstoreu_partial(result, result_block, extra_rows);
} else {
pstoreu(result, result_block);
}
result += rows;
}

template<bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows)
{
Index x = 0;
do{
Packet4f result_block = ploadu<Packet4f>(result);
result_block = pmadd(acc[x], pAlpha, result_block);
if (lhsExtraRows) {
pstoreu_partial(result, result_block, extra_rows);
} else {
pstoreu(result, result_block);
if (rhsExtraCols) {
do{
Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
} while (++x < extra_cols);
} else {
Packet4f result_block[4];
float *result2 = result;
do{
result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
result += rows;
} while (++x < 4);
x = 0;
do{
storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
} while (++x < 4);
}
}

template<Index num_acc>
EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc])
{
BFLOAT16_UNROLL
for(Index k = 0; k < num_acc; k++)
__builtin_mma_xxsetaccz(&(quad_acc[k]));
}

template<Index num_acc>
EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_acc], Packet4f (&acc)[num_acc][4])
{
BFLOAT16_UNROLL
for(Index k = 0; k < num_acc; k++)
__builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
}

template<Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows)
{
for(Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4*rows){
for(Index j = 0; j < num_lhs; j++, k++) {
storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + j*4, extra_cols, extra_rows);
}
result += rows;
} while (++x < (rhsExtraCols ? extra_cols : 4));
}
if(rhsExtraCols) {
storeResults<rhsExtraCols, lhsExtraRows>(acc[num_acc - 1], rows, pAlpha, result, extra_cols, extra_rows);
}
}
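The loadAndMultiplyF32/storeF32 helpers above implement the per-tile update result = acc*alpha + result against a column-major float scratch buffer whose column stride is the matrix row count. A scalar sketch of the same update (illustrative only; neither the function nor its parameter names come from Eigen):

// Scale one 4x4 accumulator tile by alpha and add it into the column-major buffer.
static void scale_add_tile(const float acc[4][4], float alpha, float* result, long rows)
{
  for (int col = 0; col < 4; col++) {        // one accumulator vector per output column
    float* out = result + col * rows;        // next column starts 'rows' floats later
    for (int r = 0; r < 4; r++) {
      out[r] += alpha * acc[col][r];         // same pmadd pattern as loadAndMultiplyF32
    }
  }
}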
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows, bool multiIter = false>
EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, const Index extra_cols, const Index extra_rows)
{
constexpr Index num_lhs = multiIter ? (num_packets / 4) : 1;
constexpr Index num_rhs = (num_acc + num_lhs - 1) / num_lhs;

for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += (multiIter ? 0 : 8), indexB += (multiIter ? (num_rhs*strideB) : 0), result += (multiIter ? (4*rows*num_rhs) : 4)) {
Packet4f acc[num_acc][4];
__vector_quad quad_acc[num_acc];

zeroAccumulators<num_acc>(quad_acc);

Index k;
for(k = 0; k + 2 <= depth; k += 2){
KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
}
if(depth&1){
KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA - (multiIter ? 0 : offset_row), indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
}

disassembleAccumulators<num_acc>(quad_acc, acc);

outputResults<num_acc, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
}
}
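colLoopBodyIter splits the num_acc accumulators into a num_rhs by num_lhs grid, so one pass combines num_lhs packed LHS rows with num_rhs packed RHS columns. A small illustrative check of that arithmetic for the 16-row block with 8 accumulators (assumed values, not taken from the commit):

#include <cstdio>

// Mirrors the constexpr arithmetic at the top of colLoopBodyIter.
constexpr long num_lhs_for(long num_packets, bool multiIter) {
  return multiIter ? (num_packets / 4) : 1;
}
constexpr long num_rhs_for(long num_acc, long num_lhs) {
  return (num_acc + num_lhs - 1) / num_lhs;   // ceiling division
}

int main() {
  constexpr long num_lhs = num_lhs_for(16, true);    // 4 LHS packets of 4 rows each
  constexpr long num_rhs = num_rhs_for(8, num_lhs);  // 2 RHS packets of 4 columns each
  std::printf("num_lhs=%ld num_rhs=%ld\n", num_lhs, num_rhs);  // 4 x 2 = 8 accumulators
}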
#define MAX_BFLOAT16_ACC 8

template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, Index extra_rows)
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result)
{
const Index step = (num_acc * 4); //each accumulator has 4 elements
constexpr Index step = (num_acc * 4); //each accumulator has 4 elements
const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC);

do{
for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += 8, result += 4) {
Index k;
Packet4f acc[num_acc][4];
__vector_quad quad_acc[num_acc];

BFLOAT16_UNROLL
for(k = 0; k < num_acc; k++)
__builtin_mma_xxsetaccz(&(quad_acc[k]));

for(k = 0; k + 2 <= depth; k += 2){
KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
}
if(depth&1){
KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows>(indexA - offset_row, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
}

BFLOAT16_UNROLL
for(k = 0; k < num_acc; k++)
__builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));

for(k = 0; k < (num_acc - 1); k++){
storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows);
}
storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows);
if (multiIters && ((num_acc % (num_packets / 4)) == 0)) {
colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, true>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
} else {
colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
}

indexA -= num_packets*2;
indexB += strideB*num_acc;
result += (rows*step - num_packets);
} while(!rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC) && (step <= cols - (col += step)));
result += rows*step;
} while(multiIters && (step <= cols - (col += step)));
}
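The do/while condition step <= cols - (col += step) first advances col by one step and then continues only while a full step of columns remains; since the caller only enters colLoopBody when at least one full step is available, it behaves like a conventional counted loop. A small equivalence check (illustrative only, not Eigen code):

#include <cassert>

// The caller guarantees cols >= step before the first iteration.
static long full_steps_dowhile(long cols, long step) {
  long col = 0, iters = 0;
  do {
    iters++;                                  // process one block of 'step' columns
  } while (step <= cols - (col += step));     // advance, continue while a full step remains
  return iters;
}

static long full_steps_for(long cols, long step) {
  long iters = 0;
  for (long col = 0; col + step <= cols; col += step) iters++;
  return iters;
}

int main() {
  for (long cols = 32; cols <= 96; cols++) {
    assert(full_steps_dowhile(cols, 32) == full_steps_for(cols, 32));
  }
}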
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows)
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
if (MAX_BFLOAT16_ACC > num_acc) {
colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
}
}

template<const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows)
void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
switch ((cols - col) >> 2) {
case 7:
colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 6:
colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 5:
colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 4:
colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 3:
colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 2:
colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
case 1:
colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
break;
default:
if (rhsExtraCols) {
colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
}
break;
}
}

template<const Index num_packets, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows = 0)
EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result)
{
Index col = 0;
if (cols >= (MAX_BFLOAT16_ACC * 4)) {
colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows);
colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
blockB += (strideB >> 2)*col;
result += rows*col;
}
if (cols & 3) {
colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
} else {
colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows);
colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result);
}
}

EIGEN_ALWAYS_INLINE Packet8bf convertF16toF32(const float *res)
template<bool full = true>
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
{
Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4)));
Packet16uc fp16_1 = (full) ? __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4))) : fp16_0;
return vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
}
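convertF32toBF16 above narrows four floats per __builtin_vsx_xvcvspbf16 call and packs two results into one Packet8bf. For reference, the scalar effect of a truncating float-to-bfloat16 narrowing is shown below (illustrative helper only, not part of Eigen; the hardware instruction additionally applies rounding rather than plain truncation):

#include <cstdint>
#include <cstring>

// Keep the top 16 bits of the IEEE-754 binary32 representation.
static std::uint16_t f32_to_bf16_truncate(float value)
{
  std::uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));   // reinterpret the float as raw bits
  return static_cast<std::uint16_t>(bits >> 16);
}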
template<typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
template<int N>
EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf,(N+4)/8>& block)
{
if(rows == 0 || cols == 0 || depth == 0) return;
float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
if (falpha == float(0)) return;
const Packet4f pAlpha = pset1<Packet4f>(falpha);
ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);

typedef typename DataMapper::LinearMapper LinearMapper;
Packet8us z = pset1<Packet8us>(0);
for(Index j = 0; j < cols; j++){
const LinearMapper res2 = res.getLinearMapper(0, j);
float *result2 = result + j*rows;
pstore(to + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[0].m_val)));
if (N >= 8) {
pstore(to + 4, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[0].m_val)));
}
if (N >= 16) {
pstore(to + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[1].m_val)));
pstore(to + 12, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[1].m_val)));
}
if (N >= 32) {
pstore(to + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[2].m_val)));
pstore(to + 20, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[2].m_val)));
pstore(to + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[3].m_val)));
pstore(to + 28, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[3].m_val)));
}
}

template<const Index size, typename DataMapper>
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
{
for(; i + size <= rows; i += size){
PacketBlock<Packet8bf,(size+4)/8> r32;
r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
if (size >= 16) {
r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
}
if (size >= 32) {
r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
}
storeConvertBlockBF16<size>(result + i, r32);
}
}
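storeConvertBlockBF16 widens bfloat16 to float by merging each 16-bit value with a zero half-word so it ends up in the high half of a 32-bit lane (the vec_mergeh/vec_mergel calls above). The scalar equivalent of that widening (illustrative helper only, not part of Eigen):

#include <cstdint>
#include <cstring>

// A bfloat16 value is the top 16 bits of a binary32, so shifting it above 16 zero
// bits reproduces the original float exactly.
static float bf16_to_f32(std::uint16_t bf16_bits)
{
  std::uint32_t bits = static_cast<std::uint32_t>(bf16_bits) << 16;
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}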
template<typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src)
{
typedef typename DataMapper::LinearMapper LinearMapper;
for(Index j = 0; j < cols; j++, result += rows){
const LinearMapper src2 = src.getLinearMapper(0, j);
Index i = 0;
for(; i + 32 <= rows; i+=32){
Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i + 8).m_val;
Packet8us r32_2 = res2.template loadPacket<Packet8bf>(i + 16).m_val;
Packet8us r32_3 = res2.template loadPacket<Packet8bf>(i + 24).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
pstore(result2 + i + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1)));
pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1)));
pstore(result2 + i + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_2)));
pstore(result2 + i + 20, reinterpret_cast<Packet4f>(vec_mergel(z, r32_2)));
pstore(result2 + i + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_3)));
pstore(result2 + i + 28, reinterpret_cast<Packet4f>(vec_mergel(z, r32_3)));
}
for(; i + 16 <= rows; i+=16){
Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i + 8).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
pstore(result2 + i + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1)));
pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1)));
}
for(; i + 8 <= rows; i+=8){
Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
}
for(; i + 4 <= rows; i+=4){
Packet8us r32_0 = res2.template loadPacketPartial<Packet8bf>(i + 0, 4).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
}
convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
for(; i < rows; i++){
result2[i] = res2(i);
result[i] = Eigen::bfloat16_impl::bfloat16_to_float(src2(i));
}
}
}
Index row = 0;
Index col;

if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
//Packing is done in blocks.
//There's 4 possible sizes of blocks
//Blocks of 8 columns with 16 elements (8x16)
//Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
//Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
//Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows

//Loop for LHS standard block (8x16)
const Index standard_block_size = 16;
const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks
Index bigSuffix = (2*8) * (strideA-offsetA);
const bfloat16* indexA = blockA;
const bfloat16* indexB = blockB + 4*offsetB;
Index block_index;
strideB *= 4;
offsetB *= 3;
for(block_index = 0; block_index < standard_blocks_quantity; block_index++){
indexA += 2*8*offsetA;
colLoops<16>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
row += 16;
indexA += bigSuffix;
}
//LHS (8x8) block
if(rows & 8){
indexA += 1*8*offsetA;
colLoops<8>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
row += 8;
indexA += (bigSuffix >> 1);
}
//LHS (8x4) block
if(rows & 4){
indexA += 1*4*offsetA;
colLoops<4>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
row += 4;
indexA += (bigSuffix >> 2);
}
//extra rows
Index extra_rows = rows & 3;
if(extra_rows){
//This index is the beginning of remaining block.
colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row, extra_rows);
}

//Convert back to bfloat16
template<typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper& res)
{
typedef typename DataMapper::LinearMapper LinearMapper;
Index col, row;
for(col = 0; col + 4 <= cols; col += 4){
const DataMapper res2 = res.getSubMapper(0, col);
for(row = 0; row + 8 <= rows; row += 8){
//get and save block
PacketBlock<Packet8bf,4> block;
for(Index j = 0; j < 4; j++){
block.packet[j].m_val = convertF16toF32(result + (col + j)*rows + row);
block.packet[j].m_val = convertF32toBF16(result + (col + j)*rows + row);
}

res2.template storePacketBlock<Packet8bf,4>(row, 0, block);
@@ -303,18 +330,70 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
while(col < cols){
const LinearMapper res2 = res.getLinearMapper(0, col);
float *result2 = result + col*rows;
Index r = 0;
for(; r + 8 <= rows; r += 8){
Packet8bf fp16 = convertF16toF32(result2 + r);
res2.template storePacket<Packet8bf>(r, fp16);
for(row = 0; row + 8 <= rows; row += 8){
Packet8bf fp16 = convertF32toBF16(result2 + row);
res2.template storePacket<Packet8bf>(row, fp16);
}
for(; r< rows; r++){
res2(r) = Eigen::bfloat16(result2[r]);
for(; row < rows; row++){
res2(row) = Eigen::bfloat16(result2[row]);
}
col++;
}
}

template<Index size>
EIGEN_ALWAYS_INLINE void calcColLoops(const bfloat16*& indexA, Index& row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexB, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result)
{
if ((size == 16) || (rows & size)) {
indexA += size*offsetA;
colLoops<size>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
row += size;
indexA += bigSuffix*size/16;
}
}

template<typename DataMapper>
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
if (falpha == float(0)) return;
const Packet4f pAlpha = pset1<Packet4f>(falpha);
ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);

convertArrayBF16toF32<DataMapper>(result, cols, rows, res);

Index row = 0;

if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
//Packing is done in blocks.
//There's 4 possible sizes of blocks
//Blocks of 8 columns with 16 elements (8x16)
//Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
//Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
//Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows

//Loop for LHS standard block (8x16)
Index bigSuffix = (2*8) * (strideA-offsetA);
indexB += 4*offsetB;
strideB *= 4;
offsetB *= 3;
while(row + 16 <= rows){
calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
}
//LHS (8x8) block
calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
//LHS (8x4) block
calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
//extra rows
if(rows & 3){
//This index is the beginning of remaining block.
colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
}

//Convert back to bfloat16
convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
}
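As the comments in the new driver describe, the packed LHS is walked in row blocks of 16, then at most one block of 8, one of 4, and finally a remainder of fewer than 4 rows handled with partial stores. A small illustrative check of that decomposition (not Eigen code; the function name is made up):

#include <cassert>

// How many rows each LHS block size covers, mirroring the 16/8/4/remainder walk above.
static void row_block_decomposition(long rows, long& blocks16, long& block8,
                                    long& block4, long& extra)
{
  blocks16 = rows / 16;          // full 16-row blocks
  block8   = (rows & 8) ? 1 : 0; // at most one 8-row block
  block4   = (rows & 4) ? 1 : 0; // at most one 4-row block
  extra    = rows & 3;           // remaining rows
}

int main()
{
  long b16, b8, b4, extra;
  row_block_decomposition(45, b16, b8, b4, extra);   // 45 = 2*16 + 8 + 4 + 1
  assert(b16 == 2 && b8 == 1 && b4 == 1 && extra == 1);
}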
}
}