mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-20 08:39:37 +08:00
Fix slowdown in bfloat16 MMA when rows is not a multiple of 8 or columns is not a multiple of 4.
This commit is contained in:
parent
6d4221af76
commit
6fc9de7d93
@ -1007,6 +1007,7 @@ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true>
|
||||
|
||||
Packet2ul v1[8];
|
||||
|
||||
// This is transposing and interleaving data
|
||||
v1[0] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
|
||||
v1[1] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
|
||||
v1[2] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val)));
|
||||
@ -1052,19 +1053,82 @@ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true>
|
||||
|
||||
if(PanelMode) ri += vectorSize*(stride - offset - depth);
|
||||
}
|
||||
|
||||
if(PanelMode) ri += offset;
|
||||
|
||||
for(; j < rows; j++)
|
||||
if(j + 4 <= rows)
|
||||
{
|
||||
const DataMapper lhs2 = lhs.getSubMapper(j, 0);
|
||||
for(Index i = 0; i < depth; i++)
|
||||
Index i = 0;
|
||||
|
||||
if(PanelMode) ri += 4*offset;
|
||||
|
||||
for(; i + 2 <= depth; i+=2)
|
||||
{
|
||||
blockA[ri] = lhs2(0, i);
|
||||
ri += 1;
|
||||
if(StorageOrder == ColMajor)
|
||||
{
|
||||
PacketBlock<Packet8bf,2> block;
|
||||
|
||||
block.packet[0] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
|
||||
block.packet[1] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 1, 4);
|
||||
|
||||
block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
|
||||
|
||||
pstore<bfloat16>(blockA + ri, block.packet[0]);
|
||||
} else {
|
||||
blockA[ri+0] = lhs2(0, i + 0);
|
||||
blockA[ri+1] = lhs2(0, i + 1);
|
||||
blockA[ri+2] = lhs2(1, i + 0);
|
||||
blockA[ri+3] = lhs2(1, i + 1);
|
||||
blockA[ri+4] = lhs2(2, i + 0);
|
||||
blockA[ri+5] = lhs2(2, i + 1);
|
||||
blockA[ri+6] = lhs2(3, i + 0);
|
||||
blockA[ri+7] = lhs2(3, i + 1);
|
||||
}
|
||||
|
||||
if(PanelMode) ri += stride - depth;
|
||||
ri += 2*4;
|
||||
}
|
||||
if (depth & 1)
|
||||
{
|
||||
if(StorageOrder == ColMajor)
|
||||
{
|
||||
Packet8bf lhsV = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
|
||||
|
||||
pstore_partial<bfloat16>(blockA + ri, lhsV, 4);
|
||||
} else {
|
||||
blockA[ri+0] = lhs2(0, i);
|
||||
blockA[ri+1] = lhs2(1, i);
|
||||
blockA[ri+2] = lhs2(2, i);
|
||||
blockA[ri+3] = lhs2(3, i);
|
||||
}
|
||||
|
||||
ri += 4;
|
||||
}
|
||||
|
||||
if(PanelMode) ri += 4*(stride - offset - depth);
|
||||
j += 4;
|
||||
}
|
||||
|
||||
if (j < rows)
|
||||
{
|
||||
if(PanelMode) ri += offset*(rows - j);
|
||||
|
||||
Index i = 0;
|
||||
for(; i + 2 <= depth; i+=2)
|
||||
{
|
||||
Index k = j;
|
||||
for(; k < rows; k++)
|
||||
{
|
||||
blockA[ri+0] = lhs(k, i + 0);
|
||||
blockA[ri+1] = lhs(k, i + 1);
|
||||
ri += 2;
|
||||
}
|
||||
}
|
||||
if (depth & 1)
|
||||
{
|
||||
for(; j < rows; j++)
|
||||
{
|
||||
blockA[ri] = lhs(j, i);
|
||||
ri += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -1163,18 +1227,29 @@ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false>
|
||||
if(PanelMode) ri += 4*(stride - offset - depth);
|
||||
}
|
||||
|
||||
if(PanelMode) ri += offset;
|
||||
if (j < cols)
|
||||
{
|
||||
if(PanelMode) ri += offset*(cols - j);
|
||||
|
||||
Index i = 0;
|
||||
for(; i + 2 <= depth; i+=2)
|
||||
{
|
||||
Index k = j;
|
||||
for(; k < cols; k++)
|
||||
{
|
||||
blockB[ri+0] = rhs(i + 0, k);
|
||||
blockB[ri+1] = rhs(i + 1, k);
|
||||
ri += 2;
|
||||
}
|
||||
}
|
||||
if (depth & 1)
|
||||
{
|
||||
for(; j < cols; j++)
|
||||
{
|
||||
const DataMapper rhs2 = rhs.getSubMapper(0, j);
|
||||
for(Index i = 0; i < depth; i++)
|
||||
{
|
||||
blockB[ri] = rhs2(i, 0);
|
||||
blockB[ri] = rhs(i, j);
|
||||
ri += 1;
|
||||
}
|
||||
|
||||
if(PanelMode) ri += stride - depth;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -2662,6 +2737,7 @@ void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode
|
||||
#endif
|
||||
|
||||
#ifdef __MMA__
|
||||
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
|
||||
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
|
||||
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
|
||||
{
|
||||
@ -2689,6 +2765,7 @@ void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMo
|
||||
dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, false> pack;
|
||||
pack(blockB, rhs, depth, cols, stride, offset);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
|
||||
struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
|
||||
|
@ -18,8 +18,8 @@ EIGEN_ALWAYS_INLINE void scaleAndStore(float* result, Packet4f& acc, const Packe
|
||||
pstoreu(result, result_block);
|
||||
}
|
||||
|
||||
template<Index num_packets, bool zero>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16(const bfloat16* indexA)
|
||||
template<bool zero>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA)
|
||||
{
|
||||
Packet8bf lhs1 = ploadu<Packet8bf>(indexA);
|
||||
if(zero){
|
||||
@ -31,48 +31,33 @@ EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16(const bfloat16* indexA)
|
||||
}
|
||||
|
||||
template<bool zero>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16Extra(const bfloat16* indexA, Index strideA, Index extra_rows)
|
||||
EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16Extra(const bfloat16* indexA, Index extra_rows)
|
||||
{
|
||||
Index row_count = 0;
|
||||
if (zero) {
|
||||
EIGEN_ALIGN16 bfloat16 lhs_array[8] = { Eigen::bfloat16(0) };
|
||||
do{
|
||||
lhs_array[row_count] = *indexA;
|
||||
indexA += strideA;
|
||||
} while ((row_count += 2) < extra_rows*2);
|
||||
return pload_partial<Packet8bf>(lhs_array, extra_rows*2);
|
||||
Packet8bf lhs1 = ploadu_partial<Packet8bf>(indexA, extra_rows);
|
||||
Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
|
||||
return vec_mergeh(lhs1.m_val, lhs2.m_val);
|
||||
} else {
|
||||
EIGEN_ALIGN16 int lhs_array[4];
|
||||
do{
|
||||
lhs_array[row_count] = *reinterpret_cast<const int *>(indexA);
|
||||
indexA += strideA;
|
||||
} while ((row_count += 1) < extra_rows);
|
||||
return reinterpret_cast<Packet8us>(pload_partial<Packet4i>(lhs_array, extra_rows));
|
||||
return reinterpret_cast<Packet8us>(ploadu_partial<Packet4i>(reinterpret_cast<const int *>(indexA), extra_rows));
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16ExtraRows(const bfloat16* indexA, Index strideA, Index row, Index extra_rows)
|
||||
{
|
||||
return loadBfloat16Extra<zero>(indexA + row*strideA, strideA, extra_rows);
|
||||
return loadBfloat16Extra<zero>(indexA + row*strideA, extra_rows);
|
||||
}
|
||||
|
||||
template<bool zero>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strideB, Index i, Index k)
|
||||
{
|
||||
const bfloat16* indexB = baseB + strideB*4*i + (k*4);
|
||||
Packet8bf rhs1 = ploadu<Packet8bf>(indexB);
|
||||
if(zero){
|
||||
Packet8bf rhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
|
||||
return vec_mergeh(rhs1.m_val, rhs2.m_val);
|
||||
}
|
||||
return rhs1;
|
||||
return loadBfloat16<zero>(baseB + strideB*4*i + (k*4));
|
||||
}
|
||||
|
||||
template<bool zero>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16ExtraCols(const bfloat16* blockB, Index strideB, Index offsetB, Index col, Index i, Index k, Index extra_cols)
|
||||
{
|
||||
return loadBfloat16Extra<zero>(blockB + ((col+4*i)*strideB)+k+offsetB, strideB, extra_cols);
|
||||
return loadBfloat16Extra<zero>(blockB + ((col+4*i)*strideB)+k*extra_cols+offsetB, extra_cols);
|
||||
}
|
||||
|
||||
template<Index num_acc, Index num_packets, bool zero, bool rhs_extra_cols, bool lhs_extra_rows>
|
||||
@ -93,8 +78,8 @@ EIGEN_STRONG_INLINE void KLoop
|
||||
{
|
||||
Packet8bf lhs;
|
||||
Packet8bf rhs[num_acc];
|
||||
if(lhs_extra_rows) lhs = loadLhsBfloat16ExtraRows<zero>(indexA+k, strideA, row, extra_rows);
|
||||
else lhs = loadLhsBfloat16<num_packets, zero>(indexA + k*num_packets); //a packet of bfloat16 has 8 elements
|
||||
if(lhs_extra_rows) lhs = loadLhsBfloat16ExtraRows<zero>(indexA+k*extra_rows, strideA, row, extra_rows);
|
||||
else lhs = loadBfloat16<zero>(indexA + k*num_packets); //a packet of bfloat16 has 8 elements
|
||||
BFLOAT16_UNROLL
|
||||
for(Index i = 0; i < num_acc; i++){
|
||||
if(!rhs_extra_cols)
|
||||
@ -125,7 +110,7 @@ void colLoopBody(Index& col, Index row, Index depth, Index cols, Index rows, Ind
|
||||
KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows>(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols);
|
||||
}
|
||||
if(depth&1){
|
||||
KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows>(indexA-offset_row, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols);
|
||||
KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows>(indexA-(offset_row&(num_packets-1)), indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols);
|
||||
}
|
||||
|
||||
BFLOAT16_UNROLL
|
||||
@ -174,8 +159,10 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
|
||||
typedef typename DataMapper::LinearMapper LinearMapper;
|
||||
for(Index j = 0; j < cols; j++){
|
||||
const LinearMapper res2 = res.getLinearMapper(0, j);
|
||||
float *result2 = result + j*rows;
|
||||
BFLOAT16_UNROLL
|
||||
for(Index i = 0; i < rows; i++){
|
||||
result[j*rows + i] = res2(i);
|
||||
result2[i] = res2(i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -185,15 +172,16 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
|
||||
if( strideA == -1 ) strideA = depth;
|
||||
if( strideB == -1 ) strideB = depth;
|
||||
//Packing is done in blocks.
|
||||
//There's 3 possible sizes of blocks
|
||||
//Blocks of 8 columns with 16 elements (8x16) as col major
|
||||
//Blocks of 8 columns with 8 elements (8x8) as col major. This happens when there's 16 > rows > 8
|
||||
//Blocks of 8 columns with <8 elements as row major. This happens when there's less than 8 remaining rows
|
||||
//There's 4 possible sizes of blocks
|
||||
//Blocks of 8 columns with 16 elements (8x16)
|
||||
//Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
|
||||
//Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
|
||||
//Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows
|
||||
|
||||
//Loop for LHS standard block (8x16)
|
||||
const Index standard_block_size = 16;
|
||||
const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks
|
||||
Index bigSuffix = (2*8) * (strideA-offsetA-depth);
|
||||
Index bigSuffix = (2*8) * (strideA-offsetA);
|
||||
const bfloat16* indexA = blockA;
|
||||
const Index offset_factor = 2;
|
||||
Index block_index;
|
||||
@ -215,11 +203,11 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
|
||||
}
|
||||
}
|
||||
row += 16;
|
||||
indexA += bigSuffix + 2*8*depth;
|
||||
indexA += bigSuffix;
|
||||
}
|
||||
//LHS (8x8) block
|
||||
if(rows - standard_blocks_quantity*16 >= 8){
|
||||
indexA += 1*8*offsetA + 2*8*offsetA;
|
||||
if(rows & 8){
|
||||
indexA += 1*8*offsetA;
|
||||
for(Index offset_row = 0; offset_row < 8; offset_row += 4){
|
||||
col = 0;
|
||||
colLoopBody<7, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
|
||||
@ -238,14 +226,33 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
|
||||
}
|
||||
} //end extra cols
|
||||
row += 8;
|
||||
indexA += (bigSuffix >> 1);
|
||||
}
|
||||
//LHS (8x4) block
|
||||
if(rows & 4){
|
||||
Index offset_row = (rows & 8);
|
||||
indexA += 1*4*offsetA;
|
||||
col = 0;
|
||||
colLoopBody<7, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
|
||||
colLoopBody<6, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
|
||||
colLoopBody<5, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
|
||||
colLoopBody<4, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
|
||||
colLoopBody<3, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
|
||||
colLoopBody<2, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
|
||||
colLoopBody<1, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
|
||||
if(cols > col){
|
||||
Index extra_cols= cols-col;
|
||||
|
||||
colLoopBody<1, 4, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result, extra_cols, 4);
|
||||
}
|
||||
row += 4;
|
||||
indexA += (bigSuffix >> 2);
|
||||
}
|
||||
//extra rows
|
||||
while(row < rows){
|
||||
Index extra_rows = rows-row;
|
||||
Index extra_rows_or_four = (extra_rows <= 4) ? extra_rows : 4;
|
||||
if(row < rows){
|
||||
Index extra_rows_or_four = rows-row;
|
||||
|
||||
//This index is the beginning of remaining block.
|
||||
//This last block for LHS is organized as RowMajor
|
||||
col = 0;
|
||||
colLoopBody<7, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
|
||||
colLoopBody<6, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
|
||||
@ -269,8 +276,8 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
|
||||
//get and save block
|
||||
PacketBlock<Packet8bf,4> block;
|
||||
for(Index j = 0; j < 4; j++){
|
||||
Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(pload<Packet4f>(result + (col + j)*rows + row)));
|
||||
Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(pload<Packet4f>(result + (col + j)*rows + row + 4)));
|
||||
Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(result + (col + j)*rows + row)));
|
||||
Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(result + (col + j)*rows + row + 4)));
|
||||
block.packet[j].m_val = vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user