Improve performance for Power10 MMA bfloat16 GEMM

This commit is contained in:
Chip Kerchner 2023-01-06 23:08:37 +00:00 committed by Rasmus Munk Larsen
parent fe7f527787
commit d20fe21ae4
3 changed files with 536 additions and 228 deletions

View File

@ -841,6 +841,345 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
} }
}; };
#ifdef __MMA__
// General template for lhs packing, bfloat16 specialization.
// Packs an LHS panel of bfloat16 values into blockA in the layout consumed by
// the Power10 MMA bfloat16 GEMM kernel.  Rows are handled in three tiers:
// groups of 2*vectorSize (16) rows, then vectorSize (8) rows, then single
// rows.  Within the vectorized tiers, pairs of consecutive depth values are
// interleaved per row (via vec_mergeh/vec_mergel), presumably because the MMA
// bf16 rank-2 instructions consume two k values at a time — TODO confirm
// against the kernel side.
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true>
{
EIGEN_STRONG_INLINE void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
const Index vectorSize = quad_traits<bfloat16>::vectorsize;
Index ri = 0, j = 0;
// Tier 1: pack 2*vectorSize (16) rows at a time.
for(; j + 2*vectorSize <= rows; j+=2*vectorSize)
{
const DataMapper lhs2 = lhs.getSubMapper(j, 0);
Index i = 0;
// Panel mode reserves 'offset' depth slots per packed row group.
if(PanelMode) ri += 2*vectorSize*offset;
if(StorageOrder == ColMajor)
{
// Column-major: load two consecutive depth columns for both 8-row halves
// and interleave them so each packed vector holds (k, k+1) element pairs.
for(; i + 2 <= depth; i+=2)
{
PacketBlock<Packet8bf,4> block;
block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
block.packet[2] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
block.packet[3] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 1);
Packet8bf t0, t1;
t0 = vec_mergeh(block.packet[0].m_val, block.packet[2].m_val);
t1 = vec_mergel(block.packet[0].m_val, block.packet[2].m_val);
block.packet[2] = vec_mergeh(block.packet[1].m_val, block.packet[3].m_val);
block.packet[3] = vec_mergel(block.packet[1].m_val, block.packet[3].m_val);
block.packet[0] = t0;
block.packet[1] = t1;
storeBlock<bfloat16, Packet8bf, 4>(blockA + ri, block);
ri += 2*2*vectorSize;
}
// Odd depth remainder: store the last column without interleaving.
if (depth & 1)
{
PacketBlock<Packet8bf,2> block;
block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);
ri += 2*vectorSize;
}
} else {
// Row-major: load an 8xvectorSize tile per 8-row half and transpose it
// in-register.  The 32-bit merges pair up adjacent bf16 elements; the
// 64-bit merges complete an 8x8 transpose of element pairs.
for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet8bf,8> block1, block2;
bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block2, lhs2, 1 * vectorSize, i);
Packet2ul v1[8], v2[8];
v1[0] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
v1[1] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
v1[2] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val)));
v1[3] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val)));
v1[4] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val)));
v1[5] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val)));
v1[6] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val)));
v1[7] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val)));
v2[0] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[0].m_val), reinterpret_cast<Packet4ui>(block2.packet[1].m_val)));
v2[1] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[0].m_val), reinterpret_cast<Packet4ui>(block2.packet[1].m_val)));
v2[2] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[2].m_val), reinterpret_cast<Packet4ui>(block2.packet[3].m_val)));
v2[3] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[2].m_val), reinterpret_cast<Packet4ui>(block2.packet[3].m_val)));
v2[4] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[4].m_val), reinterpret_cast<Packet4ui>(block2.packet[5].m_val)));
v2[5] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[4].m_val), reinterpret_cast<Packet4ui>(block2.packet[5].m_val)));
v2[6] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[6].m_val), reinterpret_cast<Packet4ui>(block2.packet[7].m_val)));
v2[7] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[6].m_val), reinterpret_cast<Packet4ui>(block2.packet[7].m_val)));
// Second merge stage; note the packet indices are permuted (0,2,4,6 /
// 1,3,5,7) so the interleaved store loop below emits them in order.
block1.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(v1[0],v1[2]));
block1.packet[2] = reinterpret_cast<Packet8us>(vec_mergel(v1[0],v1[2]));
block1.packet[4] = reinterpret_cast<Packet8us>(vec_mergeh(v1[1],v1[3]));
block1.packet[6] = reinterpret_cast<Packet8us>(vec_mergel(v1[1],v1[3]));
block1.packet[1] = reinterpret_cast<Packet8us>(vec_mergeh(v1[4],v1[6]));
block1.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(v1[4],v1[6]));
block1.packet[5] = reinterpret_cast<Packet8us>(vec_mergeh(v1[5],v1[7]));
block1.packet[7] = reinterpret_cast<Packet8us>(vec_mergel(v1[5],v1[7]));
block2.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(v2[0],v2[2]));
block2.packet[2] = reinterpret_cast<Packet8us>(vec_mergel(v2[0],v2[2]));
block2.packet[4] = reinterpret_cast<Packet8us>(vec_mergeh(v2[1],v2[3]));
block2.packet[6] = reinterpret_cast<Packet8us>(vec_mergel(v2[1],v2[3]));
block2.packet[1] = reinterpret_cast<Packet8us>(vec_mergeh(v2[4],v2[6]));
block2.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(v2[4],v2[6]));
block2.packet[5] = reinterpret_cast<Packet8us>(vec_mergeh(v2[5],v2[7]));
block2.packet[7] = reinterpret_cast<Packet8us>(vec_mergel(v2[5],v2[7]));
// Store the two transposed 8-row halves interleaved per depth pair.
for(Index M = 0; M < 8; M+=2) {
pstore<bfloat16>(blockA + ri + (0 * vectorSize) + (2*vectorSize * M), block1.packet[M+0]);
pstore<bfloat16>(blockA + ri + (1 * vectorSize) + (2*vectorSize * M), block1.packet[M+1]);
pstore<bfloat16>(blockA + ri + (2 * vectorSize) + (2*vectorSize * M), block2.packet[M+0]);
pstore<bfloat16>(blockA + ri + (3 * vectorSize) + (2*vectorSize * M), block2.packet[M+1]);
}
ri += 2*vectorSize*vectorSize;
}
// Scalar remainder over depth: interleave (k, k+1) pairs element-wise.
for(; i + 2 <= depth; i+=2)
{
for(Index M = 0; M < 2*vectorSize; M++) {
blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
}
ri += 2*2*vectorSize;
}
// Odd depth remainder: copy the final column un-paired.
if (depth & 1)
{
for(Index M = 0; M < 2*vectorSize; M++) {
blockA[ri + M] = lhs2(M, i);
}
ri += 2*vectorSize;
}
}
// Panel mode: skip to the next panel stride past the packed depth.
if(PanelMode) ri += 2*vectorSize*(stride - offset - depth);
}
// Tier 2: pack vectorSize (8) rows at a time; same scheme with one half.
for(; j + vectorSize <= rows; j+=vectorSize)
{
const DataMapper lhs2 = lhs.getSubMapper(j, 0);
Index i = 0;
if(PanelMode) ri += vectorSize*offset;
if(StorageOrder == ColMajor)
{
for(; i + 2 <= depth; i+=2)
{
PacketBlock<Packet8bf,2> block;
block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
block.packet[1] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
Packet8bf t0;
t0 = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
block.packet[1] = vec_mergel(block.packet[0].m_val, block.packet[1].m_val);
block.packet[0] = t0;
storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);
ri += 2*vectorSize;
}
if (depth & 1)
{
Packet8bf lhsV = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
pstore<bfloat16>(blockA + ri, lhsV);
ri += vectorSize;
}
} else {
// Row-major: single 8x8 in-register transpose (same merge network as
// the 16-row tier, one half only).
for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet8bf,8> block1;
bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
Packet2ul v1[8];
v1[0] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
v1[1] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val), reinterpret_cast<Packet4ui>(block1.packet[1].m_val)));
v1[2] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val)));
v1[3] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val), reinterpret_cast<Packet4ui>(block1.packet[3].m_val)));
v1[4] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val)));
v1[5] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val), reinterpret_cast<Packet4ui>(block1.packet[5].m_val)));
v1[6] = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val)));
v1[7] = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val), reinterpret_cast<Packet4ui>(block1.packet[7].m_val)));
block1.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(v1[0],v1[2]));
block1.packet[2] = reinterpret_cast<Packet8us>(vec_mergel(v1[0],v1[2]));
block1.packet[4] = reinterpret_cast<Packet8us>(vec_mergeh(v1[1],v1[3]));
block1.packet[6] = reinterpret_cast<Packet8us>(vec_mergel(v1[1],v1[3]));
block1.packet[1] = reinterpret_cast<Packet8us>(vec_mergeh(v1[4],v1[6]));
block1.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(v1[4],v1[6]));
block1.packet[5] = reinterpret_cast<Packet8us>(vec_mergeh(v1[5],v1[7]));
block1.packet[7] = reinterpret_cast<Packet8us>(vec_mergel(v1[5],v1[7]));
for(Index M = 0; M < 8; M++) {
pstore<bfloat16>(blockA + ri + (vectorSize * M), block1.packet[M]);
}
ri += vectorSize*vectorSize;
}
for(; i + 2 <= depth; i+=2)
{
for(Index M = 0; M < vectorSize; M++) {
blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
}
ri += 2*vectorSize;
}
if (depth & 1)
{
for(Index M = 0; M < vectorSize; M++) {
blockA[ri + M] = lhs2(M, i);
}
ri += vectorSize;
}
}
if(PanelMode) ri += vectorSize*(stride - offset - depth);
}
// Tier 3: leftover rows one at a time, packed contiguously along depth.
if(PanelMode) ri += offset;
for(; j < rows; j++)
{
const DataMapper lhs2 = lhs.getSubMapper(j, 0);
for(Index i = 0; i < depth; i++)
{
blockA[ri] = lhs2(0, i);
ri += 1;
}
if(PanelMode) ri += stride - depth;
}
}
};
// General template for rhs packing, bfloat16 specialization.
// Packs an RHS panel of bfloat16 values into blockB for the Power10 MMA
// bfloat16 GEMM kernel.  Columns are handled 4 at a time (the kernel's column
// width), then one at a time.  As with the LHS pack, consecutive depth values
// are interleaved in pairs per column.
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false>
{
EIGEN_STRONG_INLINE void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
const Index vectorSize = quad_traits<bfloat16>::vectorsize;
Index ri = 0, j = 0;
// Main loop: 4 columns at a time.
for(; j + 4 <= cols; j+=4)
{
const DataMapper rhs2 = rhs.getSubMapper(0, j);
Index i = 0;
// Panel mode reserves 'offset' depth slots per 4-column group.
if(PanelMode) ri += 4*offset;
for(; i + vectorSize <= depth; i+=vectorSize)
{
if(StorageOrder == ColMajor)
{
// Column-major: 4x8 tile transposed with a 32-bit/64-bit merge network,
// yielding packed vectors of (k, k+1) pairs across the 4 columns.
PacketBlock<Packet8bf,4> block;
bload<DataMapper, Packet8bf, 4, StorageOrder, false, 4>(block, rhs2, i, 0);
Packet2ul t0, t1, t2, t3;
t0 = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[0].m_val), reinterpret_cast<Packet4ui>(block.packet[1].m_val)));
t1 = reinterpret_cast<Packet2ul>(vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));
t2 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val), reinterpret_cast<Packet4ui>(block.packet[1].m_val)));
t3 = reinterpret_cast<Packet2ul>(vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val), reinterpret_cast<Packet4ui>(block.packet[3].m_val)));
block.packet[0] = reinterpret_cast<Packet8us>(vec_mergeh(t0, t1));
block.packet[1] = reinterpret_cast<Packet8us>(vec_mergel(t0, t1));
block.packet[2] = reinterpret_cast<Packet8us>(vec_mergeh(t2, t3));
block.packet[3] = reinterpret_cast<Packet8us>(vec_mergel(t2, t3));
storeBlock<bfloat16, Packet8bf, 4>(blockB + ri, block);
} else {
// Row-major: 4 elements per depth row are loaded partially, then pairs
// of consecutive rows are merged so each stored vector interleaves two
// depth values for the 4 columns.
PacketBlock<Packet8bf,8> block;
for (int M = 0; M < 8; M++) {
block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
}
block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
block.packet[1] = vec_mergeh(block.packet[2].m_val, block.packet[3].m_val);
block.packet[2] = vec_mergeh(block.packet[4].m_val, block.packet[5].m_val);
block.packet[3] = vec_mergeh(block.packet[6].m_val, block.packet[7].m_val);
// 16 bytes per 128-bit store, i.e. 8 bfloat16 elements.
const Index size = 16 / sizeof(bfloat16);
for (int M = 0; M < 4; M++) {
pstore<bfloat16>(blockB + ri + (M * size), block.packet[M]);
}
}
ri += 4*vectorSize;
}
// Scalar remainder: interleave (k, k+1) pairs for each of the 4 columns.
for (; i + 2 <= depth; i += 2) {
if(StorageOrder == ColMajor)
{
blockB[ri+0] = rhs2(i + 0, 0);
blockB[ri+1] = rhs2(i + 1, 0);
blockB[ri+2] = rhs2(i + 0, 1);
blockB[ri+3] = rhs2(i + 1, 1);
blockB[ri+4] = rhs2(i + 0, 2);
blockB[ri+5] = rhs2(i + 1, 2);
blockB[ri+6] = rhs2(i + 0, 3);
blockB[ri+7] = rhs2(i + 1, 3);
} else {
PacketBlock<Packet8bf,2> block;
for (int M = 0; M < 2; M++) {
block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
}
block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
pstore<bfloat16>(blockB + ri, block.packet[0]);
}
ri += 4*2;
}
// Odd depth remainder: one un-paired value per column.
if (depth & 1)
{
blockB[ri+0] = rhs2(i, 0);
blockB[ri+1] = rhs2(i, 1);
blockB[ri+2] = rhs2(i, 2);
blockB[ri+3] = rhs2(i, 3);
ri += 4;
}
// Panel mode: skip to the next panel stride past the packed depth.
if(PanelMode) ri += 4*(stride - offset - depth);
}
// Leftover columns one at a time, packed contiguously along depth.
if(PanelMode) ri += offset;
for(; j < cols; j++)
{
const DataMapper rhs2 = rhs.getSubMapper(0, j);
for(Index i = 0; i < depth; i++)
{
blockB[ri] = rhs2(i, 0);
ri += 1;
}
if(PanelMode) ri += stride - depth;
}
}
};
#endif
// General template for lhs complex packing, float64 specialization. // General template for lhs complex packing, float64 specialization.
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode> template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true> struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
@ -2322,6 +2661,64 @@ void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode
} }
#endif #endif
#ifdef __MMA__
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
{
void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>
::operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, false> pack;
pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};
template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
::operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, false> pack;
pack(blockB, rhs, depth, cols, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
{
void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>
::operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, true> pack;
pack(blockA, lhs, depth, rows, stride, offset);
}
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{
void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0);
};
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
::operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, true> pack;
pack(blockA, lhs, depth, rows, stride, offset);
}
#endif
template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>
{ {

View File

@ -1,117 +1,64 @@
#ifndef EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H #ifndef EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H #define EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
#if EIGEN_COMP_LLVM
#define BFLOAT16_UNROLL _Pragma("unroll 8")
#else
#define BFLOAT16_UNROLL _Pragma("GCC unroll(8)")
#endif
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
EIGEN_STRONG_INLINE void pgerMMAbfloat16(__vector_quad* acc, const Packet8bf& a, const Packet8bf& b, int maskX, int maskY) EIGEN_ALWAYS_INLINE void scaleAndStore(float* result, Packet4f& acc, const Packet4f& pAlpha)
{
switch(maskX){
case 15:
switch(maskY){
case 0b1111:
__builtin_mma_xvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val));
break;
case 0b0011:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b1111, 0b11, 0b11);
break;
case 0b0001:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b1111, 0b1, 0b11);
break;
case 0b0111:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b1111, 0b111, 0b11);
break;
}
break;
case 3:
switch(maskY){
case 0b1111:
__builtin_mma_xvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val));
break;
case 0b0011:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b11, 0b11, 0b11);
break;
case 0b0001:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b11, 0b1, 0b11);
break;
case 0b0111:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b11, 0b111, 0b11);
break;
}
break;
case 1:
switch(maskY){
case 0b1111:
__builtin_mma_xvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val));
break;
case 0b0011:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b1, 0b11, 0b11);
break;
case 0b0001:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b1, 0b1, 0b11);
break;
case 0b0111:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b1, 0b111, 0b11);
break;
}
break;
case 0b0111:
switch(maskY){
case 0b1111:
__builtin_mma_xvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val));
break;
case 0b0011:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b111, 0b11, 0b11);
break;
case 0b0001:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b111, 0b1, 0b11);
break;
case 0b0111:
__builtin_mma_pmxvbf16ger2pp(acc, reinterpret_cast<Packet16uc>(a.m_val), reinterpret_cast<Packet16uc>(b.m_val), 0b111, 0b111, 0b11);
break;
}
break;
}
}
EIGEN_STRONG_INLINE void scaleAndStore(float* result, float* acc, Packet4f pAlpha)
{ {
Packet4f result_block = ploadu<Packet4f>(result); Packet4f result_block = ploadu<Packet4f>(result);
Packet4f packet_pmadd = pmadd(pload<Packet4f>(acc), pAlpha, result_block); result_block = pmadd(acc, pAlpha, result_block);
pstoreu(result, packet_pmadd); pstoreu(result, result_block);
} }
template<int num_packets, bool zero> template<Index num_packets, bool zero>
EIGEN_STRONG_INLINE Packet8bf loadLhsBfloat16(const bfloat16* indexA) EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16(const bfloat16* indexA)
{ {
Packet8bf lhs1 = ploadu<Packet8bf>(indexA); Packet8bf lhs1 = ploadu<Packet8bf>(indexA);
Packet8bf lhs2;
const int packet_size = 8; //We fit 8 bfloat16 on a 128 register
if(zero){ if(zero){
lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0)); Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
return vec_mergeh(lhs1.m_val, lhs2.m_val);
} else {
return lhs1;
} }
else lhs2 = ploadu<Packet8bf>(indexA + num_packets*packet_size);
return vec_mergeh(lhs1.m_val, lhs2.m_val);
} }
template<bool zero> template<bool zero>
EIGEN_STRONG_INLINE Packet8bf loadLhsBfloat16ExtraRows(const bfloat16* indexA, Index strideA, Index row, int extra_rows) EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16Extra(const bfloat16* indexA, Index strideA, Index extra_rows)
{ {
EIGEN_ALIGN16 bfloat16 lhs_array[8] = {Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0)}; Index row_count = 0;
int count = 0; if (zero) {
const bfloat16* idxA = indexA + row*strideA; EIGEN_ALIGN16 bfloat16 lhs_array[8] = { Eigen::bfloat16(0) };
for(int row_count = 0; row_count < extra_rows; row_count++){ do{
lhs_array[count++] = *idxA; lhs_array[row_count] = *indexA;
if(!zero) lhs_array[count] = *(idxA+1); indexA += strideA;
count++; } while ((row_count += 2) < extra_rows*2);
idxA += strideA; return pload_partial<Packet8bf>(lhs_array, extra_rows*2);
} else {
EIGEN_ALIGN16 int lhs_array[4];
do{
lhs_array[row_count] = *reinterpret_cast<const int *>(indexA);
indexA += strideA;
} while ((row_count += 1) < extra_rows);
return reinterpret_cast<Packet8us>(pload_partial<Packet4i>(lhs_array, extra_rows));
} }
return pload<Packet8bf>(lhs_array);
} }
template<bool zero> template<bool zero>
EIGEN_STRONG_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strideB, int i, int k) EIGEN_ALWAYS_INLINE Packet8bf loadLhsBfloat16ExtraRows(const bfloat16* indexA, Index strideA, Index row, Index extra_rows)
{
return loadBfloat16Extra<zero>(indexA + row*strideA, strideA, extra_rows);
}
template<bool zero>
EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strideB, Index i, Index k)
{ {
const bfloat16* indexB = baseB + strideB*4*i + (k*4); const bfloat16* indexB = baseB + strideB*4*i + (k*4);
Packet8bf rhs1 = ploadu<Packet8bf>(indexB); Packet8bf rhs1 = ploadu<Packet8bf>(indexB);
@ -119,28 +66,16 @@ EIGEN_STRONG_INLINE Packet8bf loadRhsBfloat16(const bfloat16* baseB, Index strid
Packet8bf rhs2 = pset1<Packet8bf>(Eigen::bfloat16(0)); Packet8bf rhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
return vec_mergeh(rhs1.m_val, rhs2.m_val); return vec_mergeh(rhs1.m_val, rhs2.m_val);
} }
//r = vec_perm (a, b, c) return rhs1;
//Let v be the concatenation of a and b.
//Each byte of r selected by using the least-significant 5 bits of the corresponding byte of c as an index into v
//We need this elements from rhs: 0, 4, 1, 5, 2, 6, 3, 7
Packet16uc c = {0x0u, 0x1u, 0x8u, 0x9u, 0x2u, 0x3u, 0xAu, 0xB, 0x4, 0x5, 0xCu, 0xDu, 0x6u, 0x7u, 0xEu, 0xFu};
return vec_perm(rhs1.m_val, rhs1.m_val, c);
} }
template<bool zero> template<bool zero>
EIGEN_STRONG_INLINE Packet8bf loadRhsBfloat16ExtraCols(const bfloat16* blockB, Index strideB, Index offsetB, Index col, int i, int k, int extra_cols) EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16ExtraCols(const bfloat16* blockB, Index strideB, Index offsetB, Index col, Index i, Index k, Index extra_cols)
{ {
EIGEN_ALIGN16 bfloat16 rhs_vector[8] = {Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0), Eigen::bfloat16(0)}; return loadBfloat16Extra<zero>(blockB + ((col+4*i)*strideB)+k+offsetB, strideB, extra_cols);
const bfloat16* indexB = blockB + ((col+4*i)*strideB)+k+offsetB;
for(int c = 0; c < extra_cols; c++){
rhs_vector[2*c] = *indexB;
if(!zero) rhs_vector[2*c+1] = *(indexB+1);
indexB += strideB;
}
return pload<Packet8bf>(rhs_vector);
} }
template<int num_acc, int num_packets, bool zero, bool rhs_extra_cols, bool lhs_extra_rows> template<Index num_acc, Index num_packets, bool zero, bool rhs_extra_cols, bool lhs_extra_rows>
EIGEN_STRONG_INLINE void KLoop EIGEN_STRONG_INLINE void KLoop
( (
const bfloat16* indexA, const bfloat16* indexA,
@ -152,107 +87,95 @@ EIGEN_STRONG_INLINE void KLoop
Index k, Index k,
Index row, Index row,
Index col, Index col,
int extra_rows, Index extra_rows,
int extra_cols, Index extra_cols
int mask_rows = 0xF,
int mask_cols = 0xF
) )
{ {
Packet8bf lhs; Packet8bf lhs;
Packet8bf rhs[num_acc]; Packet8bf rhs[num_acc];
if(lhs_extra_rows) lhs = loadLhsBfloat16ExtraRows<zero>(indexA+k, strideA, row, extra_rows); if(lhs_extra_rows) lhs = loadLhsBfloat16ExtraRows<zero>(indexA+k, strideA, row, extra_rows);
else lhs = loadLhsBfloat16<num_packets, zero>(indexA + k*num_packets*8); //a packet of bfloat16 has 8 elements else lhs = loadLhsBfloat16<num_packets, zero>(indexA + k*num_packets); //a packet of bfloat16 has 8 elements
for(int i = 0; i < num_acc; i++){ BFLOAT16_UNROLL
for(Index i = 0; i < num_acc; i++){
if(!rhs_extra_cols) if(!rhs_extra_cols)
rhs[i] = loadRhsBfloat16<zero>(indexB, strideB, i, k); rhs[i] = loadRhsBfloat16<zero>(indexB, strideB, i, k);
else{ else{
rhs[i] = loadRhsBfloat16ExtraCols<zero>(indexB, strideB, offsetB, col, i, k, extra_cols); rhs[i] = loadRhsBfloat16ExtraCols<zero>(indexB, strideB, offsetB, col, i, k, extra_cols);
} }
pgerMMAbfloat16(&(quad_acc[i]), rhs[i], lhs, mask_cols, mask_rows); __builtin_mma_xvbf16ger2pp(&(quad_acc[i]), reinterpret_cast<Packet16uc>(rhs[i].m_val), reinterpret_cast<Packet16uc>(lhs.m_val));
} }
} }
template<const int num_acc, const int standard_block_size, const int num_packets, bool rhsExtraCols = false, bool lhsExtraRows = false> template<const Index num_acc, const Index num_packets, bool rhsExtraCols = false, bool lhsExtraRows = false>
void colLoopBody(Index* p_col, Index row, Index depth, Index cols, Index rows, int offset_row, int block_index, Packet4f pAlpha, const bfloat16* indexA, Index strideA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, int extra_cols = 0, int extra_rows = 0, int mask_cols = 0xF, int mask_rows = 0xF) void colLoopBody(Index& col, Index row, Index depth, Index cols, Index rows, Index offset_row, Index block_index, const Packet4f& pAlpha, const bfloat16* indexA, Index strideA, const bfloat16* blockB, Index strideB, Index offsetB, float* result, Index extra_cols = 0, Index extra_rows = 0)
{ {
int col = *p_col; const Index step = rhsExtraCols ? 1 : (num_acc * 4); //each accumulator has 4 elements
int count; const bfloat16* indexB = rhsExtraCols ? blockB : (blockB + 4*offsetB + strideB*col);
int max, step, bound;
const bfloat16* indexB;
if(num_acc == 1) bound = 0; while(col + step <= cols){
else bound = 1;
if(rhsExtraCols){
count = 0;
max = 1;
step = 1;
indexB = blockB;
}
else{
count = col;
step = num_acc * 4; //each accumulator has 4 elements
max = cols/step;
indexB = blockB + 4*offsetB + strideB*col;
}
while(count/step + bound < max){
Index k = 0; Index k = 0;
EIGEN_ALIGN32 float acc[num_acc][4][4]; Packet4f acc[num_acc][4];
__vector_quad quad_acc[num_acc]; __vector_quad quad_acc[num_acc];
for(int i = 0; i < num_acc; i++) BFLOAT16_UNROLL
for(Index i = 0; i < num_acc; i++)
__builtin_mma_xxsetaccz(&(quad_acc[i])); __builtin_mma_xxsetaccz(&(quad_acc[i]));
if(depth%2 != 0){ for(; k + 2 <= depth; k += 2){
KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows>(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols, mask_rows, mask_cols); KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows>(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols);
k = 1;
} }
for(; k/2 < depth/2; k += 2){ if(depth&1){
KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows>(indexA, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols, mask_rows, mask_cols); KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows>(indexA-offset_row, indexB, quad_acc, strideA, strideB, offsetB, k, row, col, extra_rows, extra_cols);
} }
for(int i = 0; i < num_acc; i++){
BFLOAT16_UNROLL
for(Index i = 0; i < num_acc; i++)
__builtin_mma_disassemble_acc((void*)acc[i], &(quad_acc[i])); __builtin_mma_disassemble_acc((void*)acc[i], &(quad_acc[i]));
for(Index i = 0; i < num_acc; i++){
if(lhsExtraRows){ if(lhsExtraRows){
for(int x = 0; x < extra_cols; x++){ float *r = result + (col+i*4)*rows + row;
for(int y = 0; y < extra_rows; y++){ for(Index x = 0; x < extra_cols; x++, r += rows){
result[((col+i*4)+x)*rows + row + y] += acc[i][x][y]*(pAlpha[0]); Packet4f result_block = ploadu_partial<Packet4f>(r, extra_rows);
} result_block = pmadd(acc[i][x], pAlpha, result_block);
pstoreu_partial<float>(r, result_block, extra_rows);
} }
} }
else{ else{
if(rhsExtraCols){ if(rhsExtraCols){
for(int x = 0; x < cols-col; x++){ float *r = result + (col+i*4)*rows + row + offset_row;
scaleAndStore(result + ((col+i*4)+x)*rows + row + offset_row,acc[i][x], pAlpha); for(Index x = 0; x < cols-col; x++, r += rows){
scaleAndStore(r,acc[i][x], pAlpha);
} }
} }
else{ else{
for(int x = 0; x < 4; x++){ float *r = result + (col+i*4)*rows + (block_index*16) + offset_row;
scaleAndStore(result + ((col+i*4)+x)*rows + (block_index*16) + offset_row,acc[i][x], pAlpha); for(Index x = 0; x < 4; x++, r += rows){
scaleAndStore(r,acc[i][x], pAlpha);
} }
} }
} }
} }
count += step; if(rhsExtraCols) return;
if(!rhsExtraCols) { indexB += strideB*step;
indexB += strideB*step; col += step;
col += step;
}
} }
*p_col = col;
} }
template<typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols> template<typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{ {
if(rows == 0 || cols == 0 || depth == 0) return; if(rows == 0 || cols == 0 || depth == 0) return;
const Packet4f pAlpha = pset1<Packet4f>(Eigen::bfloat16_impl::bfloat16_to_float(alpha)); float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
if (falpha == float(0)) return;
const Packet4f pAlpha = pset1<Packet4f>(falpha);
ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0); ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
for(int j = 0; j < cols; j++){ typedef typename DataMapper::LinearMapper LinearMapper;
for(int i = 0; i < rows; i++){ for(Index j = 0; j < cols; j++){
result[j*rows + i] = res(i,j); const LinearMapper res2 = res.getLinearMapper(0, j);
for(Index i = 0; i < rows; i++){
result[j*rows + i] = res2(i);
} }
} }
@ -268,26 +191,27 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
//Blocks of 8 columns with <8 elements as row major. This happens when there's less than 8 remaining rows //Blocks of 8 columns with <8 elements as row major. This happens when there's less than 8 remaining rows
//Loop for LHS standard block (8x16) //Loop for LHS standard block (8x16)
int standard_block_size = 16; const Index standard_block_size = 16;
const int standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks const Index standard_blocks_quantity = rows/standard_block_size; //Number of standard blocks
int bigSuffix = (2*8) * (strideA-offsetA-depth); Index bigSuffix = (2*8) * (strideA-offsetA-depth);
const bfloat16* indexA = blockA; const bfloat16* indexA = blockA;
int block_index; const Index offset_factor = 2;
Index block_index;
for(block_index = 0; block_index < standard_blocks_quantity; block_index++){ for(block_index = 0; block_index < standard_blocks_quantity; block_index++){
indexA += 2*8*offsetA; indexA += 2*8*offsetA;
for(int offset_row = 0; offset_row < standard_block_size; offset_row += 4){ //This block size has 16 rows maximum for(Index offset_row = 0; offset_row < standard_block_size; offset_row += 4){ //This block size has 16 rows maximum
col = 0; col = 0;
colLoopBody<5, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<7, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<4, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<6, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<3, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<5, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<2, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<4, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<1, 16, 2>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<3, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<2, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<1, 16>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
if(cols > col){ if(cols > col){
int extra_cols= cols-col; Index extra_cols= cols-col;
int shift = (4-extra_cols>= 0) ? 4-extra_cols: 0;
int mask_cols= 0xF >> shift;
//Remember: It doesnt make sense use multiple acc to extra_cols as we are unrolling col loop //Remember: It doesnt make sense use multiple acc to extra_cols as we are unrolling col loop
colLoopBody<1, 16, 2, true>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result, extra_cols, 4, mask_cols, 0xF); colLoopBody<1, 16, true>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result, extra_cols, 4);
} }
} }
row += 16; row += 16;
@ -296,75 +220,66 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
//LHS (8x8) block //LHS (8x8) block
if(rows - standard_blocks_quantity*16 >= 8){ if(rows - standard_blocks_quantity*16 >= 8){
indexA += 1*8*offsetA + 2*8*offsetA; indexA += 1*8*offsetA + 2*8*offsetA;
for(int offset_row = 0; offset_row < 8; offset_row += 4){ for(Index offset_row = 0; offset_row < 8; offset_row += 4){
col = 0; col = 0;
colLoopBody<5, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<7, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<4, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<6, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<3, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<5, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<2, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<4, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<1, 8, 1>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result); colLoopBody<3, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<2, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
colLoopBody<1, 8>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result);
} }
if(cols > col){ if(cols > col){
int extra_cols= cols-col; Index extra_cols= cols-col;
int shift = (4-extra_cols>= 0) ? 4-extra_cols: 0;
int mask_cols= 0xF >> shift;
for(int offset_row = 0; offset_row < 8; offset_row += 4){ for(Index offset_row = 0; offset_row < 8; offset_row += 4){
colLoopBody<1, 8, 1, true>(&col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row, strideA, blockB, strideB, offsetB, result, extra_cols, 4, mask_cols, 0xF); colLoopBody<1, 8, true>(col, row, depth, cols, rows, offset_row, block_index, pAlpha, indexA+offset_row*offset_factor, strideA, blockB, strideB, offsetB, result, extra_cols, 4);
} }
} //end extra cols } //end extra cols
row += 8; row += 8;
} }
//extra rows //extra rows
while(row < rows){ while(row < rows){
int extra_rows = rows-row; Index extra_rows = rows-row;
int shift = (4-extra_rows >= 0) ? 4-extra_rows : 0; Index extra_rows_or_four = (extra_rows <= 4) ? extra_rows : 4;
int mask_rows = 0xF >> shift;
int extra_rows_or_four = (extra_rows <= 4) ? extra_rows : 4;
//This index is the beginning of remaining block. //This index is the beginning of remaining block.
//This last block for LHS is organized as RowMajor //This last block for LHS is organized as RowMajor
col = 0; col = 0;
colLoopBody<5, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows); colLoopBody<7, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
colLoopBody<4, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows); colLoopBody<6, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
colLoopBody<3, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows); colLoopBody<5, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
colLoopBody<2, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows); colLoopBody<4, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
colLoopBody<1, 8, 1, false, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four, 0xF, mask_rows); colLoopBody<3, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
colLoopBody<2, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
colLoopBody<1, 8, false, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, 4, extra_rows_or_four);
if(cols > col){ if(cols > col){
int extra_cols= cols-col; Index extra_cols= cols-col;
int shift = (4-extra_cols>= 0) ? 4-extra_cols: 0;
int mask_cols= 0xF >> shift;
int extra_cols_or_four = (extra_cols <= 4) ? extra_cols : 4;
colLoopBody<1, 8, 1, true, true>(&col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, extra_cols_or_four, extra_rows_or_four, mask_cols, mask_rows); colLoopBody<1, 8, true, true>(col, row, depth, cols, rows, 0, block_index, pAlpha, blockA, strideA, blockB, strideB, offsetB, result, extra_cols, extra_rows_or_four);
} }
row += extra_rows_or_four; row += extra_rows_or_four;
} }
//Convert back to bfloat16 //Convert back to bfloat16
for(col = 0; col/4 < cols/4; col += 4){ for(col = 0; col + 4 <= cols; col += 4){
int row; const DataMapper res2 = res.getSubMapper(0, col);
for(row = 0; row/8 < rows/8; row += 8){ for(row = 0; row + 8 <= rows; row += 8){
//get and save block //get and save block
PacketBlock<Packet8bf,4> block; PacketBlock<Packet8bf,4> block;
for(int j = 0; j < 4; j++){ for(Index j = 0; j < 4; j++){
Packet4f temp_even, temp_odd; Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(pload<Packet4f>(result + (col + j)*rows + row)));
EIGEN_ALIGN32 float even[4], odd[4]; Packet16uc fp16_1 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(pload<Packet4f>(result + (col + j)*rows + row + 4)));
for(int i = 0; i < 4; i++){ block.packet[j].m_val = vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
even[i] = result[(col + j)*rows + row + i*2];
odd[i] = result[(col + j)*rows + row + i*2+1];
}
temp_even = pload<Packet4f>(even);
temp_odd = pload<Packet4f>(odd);
block.packet[j] = F32ToBf16(temp_even, temp_odd);
} }
res.template storePacketBlock<Packet8bf,4>(row, col, block); res2.template storePacketBlock<Packet8bf,4>(row, 0, block);
} }
//extra rows //extra rows
while(row < rows){ while(row < rows){
for(int col_off = 0; col_off < 4; col_off++){ for(Index col_off = 0; col_off < 4; col_off++){
res(row, col+col_off) = Eigen::bfloat16(result[(col+col_off)*rows+row]); res2(row, col_off) = Eigen::bfloat16(result[(col+col_off)*rows+row]);
} }
row++; row++;
} }
@ -372,8 +287,9 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
} }
//extra cols //extra cols
while(col < cols){ while(col < cols){
for(int r = 0; r < rows; r++){ const LinearMapper res2 = res.getLinearMapper(0, col);
res(r, col) = Eigen::bfloat16(result[col*rows + r]); for(Index r= 0; r< rows; r++){
res2(r) = Eigen::bfloat16(result[col*rows + r]);
} }
col++; col++;
} }

View File

@ -380,11 +380,6 @@ public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const { EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void store(const blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType, Incr>* sup, Index i, Index j, const PacketBlock<SubPacket, n>& block) const {
spbh.store(sup, i,j,block); spbh.store(sup, i,j,block);
sup->template storePacket<SubPacket>(i, j+idx, block.packet[idx]); sup->template storePacket<SubPacket>(i, j+idx, block.packet[idx]);
//for(int l = 0; l < unpacket_traits<SubPacket>::size; l++)
//{
// Scalar_ *v = &sup->operator()(i+l, j+idx);
// *v = *reinterpret_cast<Scalar_ *>(&block.packet[idx][l]);
//}
} }
}; };