Mirror of https://gitlab.com/libeigen/eigen.git (synced 2025-09-13 09:53:13 +08:00)
Fold extra column calculations into an extra MMA accumulator and other bfloat16 MMA GEMM improvements
parent 79cfc74f4d
commit fba12e02b3
@@ -11,13 +11,6 @@ namespace Eigen {
 
 namespace internal {
 
-EIGEN_ALWAYS_INLINE void scaleAndStore(float* result, Packet4f& acc, const Packet4f& pAlpha)
-{
-  Packet4f result_block = ploadu<Packet4f>(result);
-  result_block = pmadd(acc, pAlpha, result_block);
-  pstoreu(result, result_block);
-}
-
 template<bool zero>
 EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA)
 {
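
Note: the deleted scaleAndStore helper applied result = acc * alpha + result to a 4-float block; the storeResults helper added later in this diff performs the same scale-and-accumulate, with optional partial stores. A scalar sketch of the per-element operation (plain C++, not Eigen's packet code):

// Scalar model of the scale-and-accumulate done with ploadu/pmadd/pstoreu:
// each of the 4 lanes computes result[x] = acc[x] * alpha + result[x].
void scale_and_store_model(float* result, const float acc[4], float alpha)
{
  for (int x = 0; x < 4; x++)
    result[x] = acc[x] * alpha + result[x];
}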
@@ -31,120 +24,159 @@ EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA)
 }
 
 template<bool zero>
-EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16Extra(const bfloat16* indexA, Index extra_rows)
+EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index strideB, Index i)
 {
-  if (zero) {
-    Packet8bf lhs1 = ploadu_partial<Packet8bf>(indexA, extra_rows);
-    Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
-    return vec_mergeh(lhs1.m_val, lhs2.m_val);
-  } else {
-    return reinterpret_cast<Packet8us>(ploadu_partial<Packet4i>(reinterpret_cast<const int *>(indexA), extra_rows));
-  }
-}
-
-template<bool zero>
-EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16ExtraRows(const bfloat16* indexA, Index strideA, Index row, Index extra_rows)
-{
-  return loadBfloat16Extra<zero>(indexA + row*strideA, extra_rows);
-}
-
-template<bool zero>
-EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strideB, Index i, Index k)
-{
-  return loadBfloat16<zero>(baseB + strideB*4*i + (k*4));
-}
-
-template<bool zero>
-EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16ExtraCols(const bfloat16* blockB, Index strideB, Index offsetB, Index col, Index i, Index k, Index extra_cols)
-{
-  return loadBfloat16Extra<zero>(blockB + ((col+4*i)*strideB)+k*extra_cols+offsetB, extra_cols);
+  return loadBfloat16<zero>(blockB + strideB*i);
 }
 
-template<Index num_acc, Index num_packets, bool zero, bool rhs_extra_cols, bool lhs_extra_rows>
-EIGEN_STRONG_INLINE void KLoop
+template<Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void KLoop
 (
   const bfloat16* indexA,
   const bfloat16* indexB,
   __vector_quad (&quad_acc)[num_acc],
-  Index strideA,
   Index strideB,
-  Index offsetB,
   Index k,
-  Index row,
-  Index col,
-  Index extra_rows,
-  Index extra_cols
+  Index offsetB,
+  Index extra_cols,
+  Index extra_rows
 )
 {
-  Packet8bf lhs;
+  Packet8bf lhs = loadBfloat16<zero>(indexA + k*(lhsExtraRows ? extra_rows : num_packets)); //a packet of bfloat16 has 8 elements
   Packet8bf rhs[num_acc];
-  if(lhs_extra_rows) lhs = loadLhsBfloat16ExtraRows<zero>(indexA+k*extra_rows, strideA, row, extra_rows);
-  else lhs = loadBfloat16<zero>(indexA + k*num_packets); //a packet of bfloat16 has 8 elements
+
+  for(Index i = 0; i < (num_acc - (rhsExtraCols ? 1 : 0)); i++){
+    rhs[i] = loadRhsBfloat16<zero>(indexB + k*4, strideB, i);
+  }
+  if(rhsExtraCols) {
+    rhs[num_acc - 1] = loadRhsBfloat16<zero>(indexB + k*extra_cols - offsetB, strideB, num_acc - 1);
+  }
 
   BFLOAT16_UNROLL
   for (Index i = 0; i < num_acc; i++) {
-    if(!rhs_extra_cols)
-      rhs[i] = loadRhsBfloat16<zero>(indexB, strideB, i, k);
-    else{
-      rhs[i] = loadRhsBfloat16ExtraCols<zero>(indexB, strideB, offsetB, col, i, k, extra_cols);
-    }
     __builtin_mma_xvbf16ger2pp(&(quad_acc[i]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs.m_val));
   }
 }
 
-template<const Index num_acc, const Index num_packets, bool rhsExtraCols = false, bool lhsExtraRows = false>
-void colLoopBody(Index& col, Index row, Index depth, Index cols, Index rows, Index offset_row, Index block_index, const Packet4f& pAlpha, const bfloat16* indexA, Index strideA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_cols = 0, Index extra_rows = 0)
-{
-  const Index step = rhsExtraCols ? 1 : (num_acc * 4); //each accumulator has 4 elements
-  const bfloat16* indexB = rhsExtraCols ? blockB : (blockB + 4*offsetB + strideB*col);
-
-  while(col + step <= cols){
-    Index k = 0;
-    Packet4f acc[num_acc][4];
-    __vector_quad quad_acc[num_acc];
-
-    BFLOAT16_UNROLL
-    for(Index i = 0; i < num_acc; i++)
-      __builtin_mma_xxsetaccz(&(quad_acc[i]));
-
-    for(; k + 2 <= depth; k += 2){
-      KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows>(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols);
-    }
-    if(depth&1){
-      KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows>(indexA-(offset_row&(num_packets-1)), indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols);
-    }
-
-    BFLOAT16_UNROLL
-    for(Index i = 0; i < num_acc; i++)
-      __builtin_mma_disassemble_acc((void*)acc[i], &(quad_acc[i]));
-
-    for(Index i = 0; i < num_acc; i++){
-      if(lhsExtraRows){
-        float *r = result + (col+i*4)*rows + row;
-        for(Index x = 0; x < extra_cols; x++, r += rows){
-          Packet4f result_block = ploadu_partial<Packet4f>(r, extra_rows);
-          result_block = pmadd(acc[i][x], pAlpha, result_block);
-          pstoreu_partial<float>(r, result_block, extra_rows);
-        }
-      }
-      else{
-        if (rhsExtraCols) {
-          float *r = result + (col+i*4)*rows + row + offset_row;
-          for(Index x = 0; x < cols-col; x++, r += rows){
-            scaleAndStore(r,acc[i][x], pAlpha);
-          }
-        }
-        else{
-          float *r = result + (col+i*4)*rows + (block_index*16) + offset_row;
-          for(Index x = 0; x < 4; x++, r += rows){
-            scaleAndStore(r,acc[i][x], pAlpha);
-          }
-        }
-      }
-    }
-    if(rhsExtraCols) return;
-    indexB += strideB*step;
-    col += step;
-  }
-}
+template <bool rhsExtraCols, bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows)
+{
+  Index x = 0;
+  do{
+    Packet4f result_block = ploadu<Packet4f>(result);
+    result_block = pmadd(acc[x], pAlpha, result_block);
+    if (lhsExtraRows) {
+      pstoreu_partial(result, result_block, extra_rows);
+    } else {
+      pstoreu(result, result_block);
+    }
+    result += rows;
+  } while (++x < (rhsExtraCols ? extra_cols : 4));
+}
+
+#define MAX_BFLOAT16_ACC 8
+
+template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
+void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result, Index extra_rows)
+{
+  const Index step = (num_acc * 4); //each accumulator has 4 elements
+  const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
+
+  do{
+    for(Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += 8, result += 4) {
+      Index k;
+      Packet4f acc[num_acc][4];
+      __vector_quad quad_acc[num_acc];
+
+      BFLOAT16_UNROLL
+      for(k = 0; k < num_acc; k++)
+        __builtin_mma_xxsetaccz(&(quad_acc[k]));
+
+      for(k = 0; k + 2 <= depth; k += 2){
+        KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
+      }
+      if(depth&1){
+        KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows>(indexA - offset_row, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
+      }
+
+      BFLOAT16_UNROLL
+      for(k = 0; k < num_acc; k++)
+        __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
+
+      for(k = 0; k < (num_acc - 1); k++){
+        storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows);
+      }
+      storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result + k*4*rows, extra_cols, extra_rows);
+    }
+
+    indexA -= num_packets*2;
+    indexB += strideB*num_acc;
+    result += (rows*step - num_packets);
+  } while(!rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC) && (step <= cols - (col += step)));
+}
+
+template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
+EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows)
+{
+  if (MAX_BFLOAT16_ACC > num_acc) {
+    colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+  }
+}
+
+template<const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
+void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows)
+{
+  switch ((cols - col) >> 2) {
+  case 7:
+    colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    break;
+  case 6:
+    colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    break;
+  case 5:
+    colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    break;
+  case 4:
+    colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    break;
+  case 3:
+    colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    break;
+  case 2:
+    colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    break;
+  case 1:
+    colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    break;
+  default:
+    if (rhsExtraCols) {
+      colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+    }
+    break;
+  }
+}
+
+template<const Index num_packets, bool lhsExtraRows = false>
+EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_rows = 0)
+{
+  Index col = 0;
+  if (cols >= (MAX_BFLOAT16_ACC * 4)) {
+    colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows);
+    blockB += (strideB >> 2)*col;
+    result += rows*col;
+  }
+  if (cols & 3) {
+    colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result, extra_rows);
+  } else {
+    colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0, result, extra_rows);
+  }
+}
+
+EIGEN_ALWAYS_INLINE Packet8bf convertF16toF32(const float *res)
+{
+  Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
+  Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4)));
+  return vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
+}
 
 template<typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
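
Note: the core of this change is visible in KLoop and colLoopBodyExtraN above. A ragged remainder of 1-3 output columns is no longer a separate scalar store path; it rides along as one extra MMA accumulator whose loads and stores are clipped to extra_cols (hence num_acc + (rhsExtraCols ? 1 : 0)). A scalar model of that control flow, with hypothetical names, for one LHS row against a row-major depth-by-cols RHS:

// Scalar model of "fold extra columns into an extra accumulator":
// each accumulator covers 4 output columns; a 1-3 column remainder
// becomes one more accumulator with a masked (partial) store.
void matvec_cols_model(const float* a /* depth */, const float* B /* depth x cols */,
                       float* c /* cols */, int depth, int cols)
{
  const int full = cols / 4;        // accumulators covering 4 full columns each
  const int extra_cols = cols & 3;  // ragged remainder, 0..3
  const int num_acc = full + (extra_cols ? 1 : 0);
  for (int acc = 0; acc < num_acc; acc++) {
    const int n = (extra_cols && acc == full) ? extra_cols : 4;  // last acc may be partial
    float sum[4] = {0.0f, 0.0f, 0.0f, 0.0f};
    for (int k = 0; k < depth; k++)
      for (int j = 0; j < n; j++)
        sum[j] += a[k] * B[k*cols + acc*4 + j];
    for (int j = 0; j < n; j++)   // clipped store for the partial accumulator
      c[acc*4 + j] += sum[j];
  }
}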
@@ -157,17 +189,33 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
   ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
 
   typedef typename DataMapper::LinearMapper LinearMapper;
+  Packet4f z = pset1<Packet4f>(float(0));
   for(Index j = 0; j < cols; j++){
     const LinearMapper res2 = res.getLinearMapper(0, j);
     float *result2 = result + j*rows;
+    Index i = 0;
+    for(; i + 32 <= rows; i+=32){
+      Packet4f r32_0 = reinterpret_cast<Packet4f>(res2.template loadPacket<Packet8bf>(i + 0).m_val);
+      Packet4f r32_1 = reinterpret_cast<Packet4f>(res2.template loadPacket<Packet8bf>(i + 8).m_val);
+      Packet4f r32_2 = reinterpret_cast<Packet4f>(res2.template loadPacket<Packet8bf>(i + 16).m_val);
+      Packet4f r32_3 = reinterpret_cast<Packet4f>(res2.template loadPacket<Packet8bf>(i + 24).m_val);
+      pstore(result2 + i + 0, vec_mergeo(r32_0, z));
+      pstore(result2 + i + 4, vec_mergee(r32_0, z));
+      pstore(result2 + i + 8, vec_mergeo(r32_1, z));
+      pstore(result2 + i + 12, vec_mergee(r32_1, z));
+      pstore(result2 + i + 16, vec_mergeo(r32_2, z));
+      pstore(result2 + i + 20, vec_mergee(r32_2, z));
+      pstore(result2 + i + 24, vec_mergeo(r32_3, z));
+      pstore(result2 + i + 28, vec_mergee(r32_3, z));
+    }
     BFLOAT16_UNROLL
-    for(Index i = 0; i < rows; i++){
+    for(; i < rows; i++){
       result2[i] = res2(i);
     }
   }
 
   Index row = 0;
-  Index col = 0;
+  Index col;
 
   if( strideA == -1 ) strideA = depth;
   if( strideB == -1 ) strideB = depth;
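
Note: the loop added above widens the bfloat16 result matrix to float 32 values at a time, interleaving the loaded halfwords with a zero vector via vec_mergeo/vec_mergee so each bfloat16 lands in the upper half of a 32-bit lane. The per-element effect is the standard bfloat16-to-float widening; a scalar sketch:

#include <cstdint>
#include <cstring>

// Scalar model of bfloat16 -> float widening: a bfloat16 is the upper
// 16 bits of an IEEE float32, so widening is a 16-bit shift into the
// high half of the 32-bit word.
float bf16_to_f32(std::uint16_t bf)
{
  std::uint32_t bits = static_cast<std::uint32_t>(bf) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}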
@@ -183,90 +231,35 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
   const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks
   Index bigSuffix = (2*8) * (strideA-offsetA);
   const bfloat16* indexA = blockA;
-  const Index offset_factor = 2;
+  const bfloat16* indexB = blockB + 4*offsetB;
   Index block_index;
+  strideB *= 4;
+  offsetB *= 3;
   for(block_index = 0; block_index < standard_blocks_quantity; block_index++){
     indexA += 2*8*offsetA;
-    for(Index offset_row = 0; offset_row < standard_block_size; offset_row += 4){ //This block size has 16 rows maximum
-      col = 0;
-      colLoopBody<7, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<6, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<5, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<4, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<3, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<2, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<1, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      if(cols > col){
-        Index extra_cols= cols-col;
-        //Remember: It doesnt make sense use multiple acc to extra_cols as we are unrolling col loop
-        colLoopBody<1, 16, true>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result, extra_cols, 4);
-      }
-    }
+    colLoops<16>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
     row += 16;
     indexA += bigSuffix;
   }
   //LHS (8x8) block
   if(rows & 8){
     indexA += 1*8*offsetA;
-    for(Index offset_row = 0; offset_row < 8; offset_row += 4){
-      col = 0;
-      colLoopBody<7, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<6, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<5, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<4, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<3, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<2, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-      colLoopBody<1, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
-    }
-    if(cols > col){
-      Index extra_cols= cols-col;
-
-      for(Index offset_row = 0; offset_row < 8; offset_row += 4){
-        colLoopBody<1, 8, true>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result, extra_cols, 4);
-      }
-    } //end extra cols
+    colLoops<8>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
     row += 8;
     indexA += (bigSuffix >> 1);
   }
   //LHS (8x4) block
   if(rows & 4){
-    Index offset_row = (rows & 8);
     indexA += 1*4*offsetA;
-    col = 0;
-    colLoopBody<7, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
-    colLoopBody<6, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
-    colLoopBody<5, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
-    colLoopBody<4, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
-    colLoopBody<3, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
-    colLoopBody<2, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
-    colLoopBody<1, 4>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result);
-    if(cols > col){
-      Index extra_cols= cols-col;
-
-      colLoopBody<1, 4, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, indexA, strideA, blockB, strideB, offsetB, result, extra_cols, 4);
-    }
+    colLoops<4>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
     row += 4;
     indexA += (bigSuffix >> 2);
   }
   //extra rows
-  if(row < rows){
-    Index extra_rows_or_four = rows-row;
+  Index extra_rows = rows & 3;
+  if(extra_rows){
 
     //This index is the beginning of remaining block.
-    col = 0;
-    colLoopBody<7, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
-    colLoopBody<6, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
-    colLoopBody<5, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
-    colLoopBody<4, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
-    colLoopBody<3, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
-    colLoopBody<2, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
-    colLoopBody<1, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
-    if(cols > col){
-      Index extra_cols= cols-col;
-
-      colLoopBody<1, 8, true, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, extra_cols, extra_rows_or_four);
-    }
-    row += extra_rows_or_four;
+    colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row, extra_rows);
   }
 
   //Convert back to bfloat16
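
Note: the driver now walks the packed LHS in fixed row panels and delegates every column loop to colLoops: full 16-row panels first, then an optional 8-row and 4-row panel, then a 1-3 row remainder handled with partial stores. A runnable sketch of just that dispatch (panel bodies stubbed with prints; not Eigen code):

#include <cstdio>

// Scalar model of the row-panel dispatch above.
void row_panels_model(int rows)
{
  int row = 0;
  while (row + 16 <= rows) { std::printf("16-row panel at %d\n", row); row += 16; }
  if (rows & 8) { std::printf("8-row panel at %d\n", row); row += 8; }
  if (rows & 4) { std::printf("4-row panel at %d\n", row); row += 4; }
  if (rows & 3) { std::printf("%d extra rows at %d\n", rows & 3, row); }
}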
@@ -276,9 +269,7 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
       //get and save block
       PacketBlock<Packet8bf,4> block;
       for(Index j = 0; j < 4; j++){
-        Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(result + (col + j)*rows + row)));
-        Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(result + (col + j)*rows + row + 4)));
-        block.packet[j].m_val = vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
+        block.packet[j].m_val = convertF16toF32(result + (col + j)*rows + row);
      }
 
      res2.template storePacketBlock<Packet8bf,4>(row, 0, block);
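
Note: despite its name, convertF16toF32 narrows eight float results to bfloat16 via __builtin_vsx_xvcvspbf16 plus vec_pack; the hunks above and below replace open-coded conversions with calls to it. A scalar sketch of the narrowing, assuming round-to-nearest-even (NaN handling elided for brevity):

#include <cstdint>
#include <cstring>

// Scalar model of float32 -> bfloat16 narrowing with round-to-nearest-even:
// add a rounding bias that breaks ties toward the even result, then keep
// the top 16 bits.
std::uint16_t f32_to_bf16(float f)
{
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  const std::uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);  // RNE tie-break
  return static_cast<std::uint16_t>((bits + rounding) >> 16);
}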
@@ -295,8 +286,14 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
   //extra cols
   while(col < cols){
     const LinearMapper res2 = res.getLinearMapper(0, col);
-    for(Index r= 0; r< rows; r++){
-      res2(r) = Eigen::bfloat16(result[col*rows + r]);
+    float *result2 = result + col*rows;
+    Index r = 0;
+    for(; r + 8 <= rows; r += 8){
+      Packet8bf fp16 = convertF16toF32(result2 + r);
+      res2.template storePacket<Packet8bf>(r, fp16);
+    }
+    for(; r< rows; r++){
+      res2(r) = Eigen::bfloat16(result2[r]);
     }
     col++;
   }