Unroll F32 to BF16 loop - 1.8X faster conversions for LLVM. Use vector pairs for GCC.

Chip Kerchner 2023-05-01 16:54:16 +00:00 committed by Rasmus Munk Larsen
parent 874f5947f4
commit 6418ac0285
4 changed files with 50 additions and 28 deletions

View File

@@ -2839,13 +2839,13 @@ EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask)
 }
 
 template<bool lhsExtraRows, bool odd, Index size>
-EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index col, Index rows, const bfloat16* src, Index extra_rows)
+EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index rows, const bfloat16* src, Index extra_rows)
 {
   Packet4f dup[4*4];
   Packet8bf data[4];
 
   for (Index i = 0; i < size; i++) {
-    data[i] = ploadu<Packet8bf>(src + col + rows*i);
+    data[i] = ploadu<Packet8bf>(src + rows*i);
   }
 
   for (Index i = 0, j = 0; i < size; i++, j += 4) {
@@ -2876,15 +2876,16 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index
 
 template<bool lhsExtraRows>
 EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float *result, Index cols, Index rows, const bfloat16* src, Index delta, Index extra_rows)
 {
-  Index col2 = 0, col = 0;
-  for(; col + 4*2 <= cols; col += 4*2, col2 += 4*rows, result += 4*4*4) {
-    convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,4>(result, col2 + delta*2, rows, src, extra_rows);
+  Index col = 0;
+  src += delta*2;
+  for(; col + 4*2 <= cols; col += 4*2, result += 4*4*4, src += 4*rows) {
+    convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,4>(result, rows, src, extra_rows);
   }
-  for(; col + 2 <= cols; col += 2, col2 += rows, result += 4*4) {
-    convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,1>(result, col2 + delta*2, rows, src, extra_rows);
+  for(; col + 2 <= cols; col += 2, result += 4*4, src += rows) {
+    convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,1>(result, rows, src, extra_rows);
   }
   if (cols & 1) {
-    convertArrayPointerBF16toF32DupOne<lhsExtraRows,true,1>(result, col2 + delta, rows, src, extra_rows);
+    convertArrayPointerBF16toF32DupOne<lhsExtraRows,true,1>(result, rows, src - delta, extra_rows);
   }
 }
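The rewrite above drops the running col2 byte offset and instead bumps src and result in the loop increments, so each helper call reads from a loop-invariant base pointer. A minimal sketch of that strength-reduction pattern, with hypothetical names (copy_indexed, copy_bumped) standing in for the Eigen helpers:

    #include <cstddef>

    // Before: each inner access recomputes base + running offset.
    void copy_indexed(float* dst, const float* src, std::ptrdiff_t rows, std::ptrdiff_t cols) {
      for (std::ptrdiff_t col = 0, off = 0; col < cols; col++, off += rows) {
        for (std::ptrdiff_t r = 0; r < rows; r++)
          dst[off + r] = src[off + r];
      }
    }

    // After: advance the pointers once per column; the inner loop then
    // indexes from a fixed base, which is simpler address arithmetic.
    void copy_bumped(float* dst, const float* src, std::ptrdiff_t rows, std::ptrdiff_t cols) {
      for (std::ptrdiff_t col = 0; col < cols; col++, src += rows, dst += rows) {
        for (std::ptrdiff_t r = 0; r < rows; r++)
          dst[r] = src[r];
      }
    }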
@@ -2892,7 +2893,7 @@ template<const Index size, bool non_unit_stride>
 EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
 {
   constexpr Index extra = ((size < 4) ? 4 : size);
-  for(; i + size <= rows; i += extra, src += extra*resInc){
+  while (i + size <= rows) {
     PacketBlock<Packet8bf,(size+7)/8> r32;
     r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
     if (size >= 16) {
@@ -2903,6 +2904,8 @@ EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index
       r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
     }
     storeConvertBlockBF16<size>(result + i, r32, rows & 3);
+    i += extra; src += extra*resInc;
+    if (size != 32) break;
   }
 }
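Here the for loop becomes a while with a compile-time break: these helpers are instantiated for decreasing residual sizes, so only the size == 32 instantiation can iterate more than once, and with the break visible the smaller instantiations collapse to straight-line code. A stand-alone sketch of the pattern, under the assumption of a scalar conversion; convert_chunk and convert_all are stand-ins, not the Eigen code:

    #include <cstdint>
    #include <cstring>

    // bf16 -> f32: the 16 payload bits become the top half of the float.
    static float bf16_to_f32(uint16_t b) {
      uint32_t bits = uint32_t(b) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof f);
      return f;
    }

    template <int size>
    void convert_chunk(int& i, const uint16_t* src, float* dst, int rows) {
      while (i + size <= rows) {
        for (int k = 0; k < size; k++)
          dst[i + k] = bf16_to_f32(src[i + k]);
        i += size;
        if (size != 32) break;  // smaller residual sizes run at most once
      }
    }

    void convert_all(const uint16_t* src, float* dst, int rows) {
      int i = 0;
      convert_chunk<32>(i, src, dst, rows);  // the only instantiation that loops
      convert_chunk<16>(i, src, dst, rows);  // at most one iteration each
      convert_chunk<8>(i, src, dst, rows);
      convert_chunk<4>(i, src, dst, rows);
    }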
@@ -3131,7 +3134,7 @@ template<const Index size, typename DataMapper>
 EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
 {
   constexpr Index extra = ((size < 4) ? 4 : size);
-  for(; i + size <= rows; i += extra){
+  while (i + size <= rows) {
     PacketBlock<Packet8bf,(size+7)/8> r32;
     r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
     if (size >= 16) {
@@ -3142,6 +3145,8 @@ EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, c
       r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
     }
     storeConvertBlockBF16<size>(result + i, r32, rows & 3);
+    i += extra;
+    if (size != 32) break;
   }
 }
@@ -3171,18 +3176,18 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float *result, Index col, I
   const DataMapper res2 = res.getSubMapper(0, col);
   Index row;
   float *result2 = result + col*rows;
-  for(row = 0; row + 8 <= rows; row += 8){
+  for(row = 0; row + 8 <= rows; row += 8, result2 += 8){
     // get and save block
     PacketBlock<Packet8bf,size> block;
     for(Index j = 0; j < size; j++){
-      block.packet[j] = convertF32toBF16VSX(result2 + j*rows + row);
+      block.packet[j] = convertF32toBF16VSX(result2 + j*rows);
     }
     res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
   }
   // extra rows
   if(row < rows){
     for(Index j = 0; j < size; j++){
-      Packet8bf fp16 = convertF32toBF16VSX(result2 + j*rows + row);
+      Packet8bf fp16 = convertF32toBF16VSX(result2 + j*rows);
       res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
     }
   }
@@ -3196,9 +3201,16 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float *result, Index cols, Ind
     convertArrayF32toBF16ColVSX<DataMapper,4>(result, col, rows, res);
   }
   // extra cols
-  while(col < cols){
+  switch (cols - col) {
+  case 1:
     convertArrayF32toBF16ColVSX<DataMapper,1>(result, col, rows, res);
-    col++;
+    break;
+  case 2:
+    convertArrayF32toBF16ColVSX<DataMapper,2>(result, col, rows, res);
+    break;
+  case 3:
+    convertArrayF32toBF16ColVSX<DataMapper,3>(result, col, rows, res);
+    break;
   }
 }
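Since the main loop above strides four columns at a time, at most three columns remain, so the while loop that called the one-column kernel repeatedly is replaced by a single switch into a kernel templated on the exact remainder. The shape of the dispatch, with a hypothetical convert_cols kernel standing in for the Eigen templates:

    #include <cstdio>

    // Stand-in for the convertArrayF32toBF16Col* templates: a kernel whose
    // column count is a compile-time constant.
    template <int n>
    void convert_cols(int col) {
      std::printf("convert %d column(s) at column %d\n", n, col);
    }

    void convert_array(int cols) {
      int col = 0;
      for (; col + 4 <= cols; col += 4)
        convert_cols<4>(col);
      switch (cols - col) {  // 0..3 columns left: one jump, fixed trip counts
        case 1: convert_cols<1>(col); break;
        case 2: convert_cols<2>(col); break;
        case 3: convert_cols<3>(col); break;
      }
    }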

View File

@@ -215,15 +215,10 @@ EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Pac
 EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
 {
   Packet16uc fp16[2];
-#if EIGEN_COMP_LLVM
   __vector_pair fp16_vp = *reinterpret_cast<__vector_pair *>(const_cast<float *>(res));
   __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
   fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
   fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
-#else
-  fp16[0] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
-  fp16[1] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4)));
-#endif
   return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
 }
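With the #if EIGEN_COMP_LLVM guard gone, the vector-pair path from the commit title now serves GCC as well: one paired 32-byte load is split into two 16-byte vectors, each rounded with xvcvspbf16. A stand-alone sketch of the same sequence, using the Power10 builtins seen above (compile with something like -mcpu=power10); a sketch of the pattern, not the Eigen implementation:

    #include <altivec.h>

    typedef __vector unsigned char  vuc;
    typedef __vector unsigned int   vui;
    typedef __vector unsigned short vus;

    // Convert eight f32 values to eight bf16 values.
    vus f32_to_bf16_8(const float* src) {
      vuc halves[2];
      // One paired 32-byte load, then split into two 16-byte vectors.
      __vector_pair vp = *reinterpret_cast<__vector_pair*>(const_cast<float*>(src));
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(halves), &vp);
      // Each 32-bit lane now holds its rounded bf16 result in the low 16 bits.
      halves[0] = __builtin_vsx_xvcvspbf16(halves[0]);
      halves[1] = __builtin_vsx_xvcvspbf16(halves[1]);
      // Pack the per-lane 16-bit results into one vector of eight.
      return vec_pack(reinterpret_cast<vui>(halves[0]), reinterpret_cast<vui>(halves[1]));
    }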
@@ -233,18 +228,20 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float *result, Index col, Inde
   const DataMapper res2 = res.getSubMapper(0, col);
   Index row;
   float *result2 = result + col*rows;
-  for(row = 0; row + 8 <= rows; row += 8){
+  for(row = 0; row + 8 <= rows; row += 8, result2 += 8){
     // get and save block
     PacketBlock<Packet8bf,size> block;
+    BFLOAT16_UNROLL
     for(Index j = 0; j < size; j++){
-      block.packet[j] = convertF32toBF16(result2 + j*rows + row);
+      block.packet[j] = convertF32toBF16(result2 + j*rows);
     }
     res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
   }
   // extra rows
   if(row < rows){
+    BFLOAT16_UNROLL
     for(Index j = 0; j < size; j++){
-      Packet8bf fp16 = convertF32toBF16(result2 + j*rows + row);
+      Packet8bf fp16 = convertF32toBF16(result2 + j*rows);
       res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
     }
   }
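BFLOAT16_UNROLL, new in this hunk, is an unroll hint placed before the size-bounded j loops. Its definition is not part of this diff; a plausible pragma-based form (an assumption, not quoted from the commit) would be:

    // Assumed definition: per-compiler unroll pragma, not shown in this diff.
    #if defined(__clang__)
    #define BFLOAT16_UNROLL _Pragma("unroll 8")
    #else
    #define BFLOAT16_UNROLL _Pragma("GCC unroll 8")
    #endif

    // Usage: directly before a loop with a small compile-time trip count,
    // asking the compiler to emit it as straight-line code.
    template <int size>
    void scale2x(float (&v)[size]) {
      BFLOAT16_UNROLL
      for (int j = 0; j < size; j++)
        v[j] *= 2.0f;
    }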
@@ -254,7 +251,7 @@ template<const Index size, bool non_unit_stride = false>
 EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
 {
   constexpr Index extra = ((size < 8) ? 8 : size);
-  for(; i + size <= rows; i += extra, dst += extra*resInc){
+  while (i + size <= rows){
     PacketBlock<Packet8bf,(size+7)/8> r32;
     r32.packet[0] = convertF32toBF16(result + i + 0);
     if (size >= 16) {
@@ -272,6 +269,8 @@ EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index
       storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
       storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
     }
+    i += extra; dst += extra*resInc;
+    if (size != 32) break;
   }
 }
@@ -293,9 +292,16 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index
     convertArrayF32toBF16Col<DataMapper,4>(result, col, rows, res);
   }
   // extra cols
-  while(col < cols){
+  switch (cols - col) {
+  case 1:
     convertArrayF32toBF16Col<DataMapper,1>(result, col, rows, res);
-    col++;
+    break;
+  case 2:
+    convertArrayF32toBF16Col<DataMapper,2>(result, col, rows, res);
+    break;
+  case 3:
+    convertArrayF32toBF16Col<DataMapper,3>(result, col, rows, res);
+    break;
   }
 }

View File

@@ -657,7 +657,7 @@ template<const Index size, bool inc = false>
 EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
 {
   constexpr Index extra = ((size < 8) ? 8 : size);
-  for(; i + size <= rows; i += extra, dst += extra*resInc){
+  while (i + size <= rows) {
     PacketBlock<Packet8bf,(size+7)/8> r32;
     r32.packet[0] = convertF32toBF16VSX(result + i + 0);
     if (size >= 16) {
@@ -675,6 +675,8 @@ EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Ind
       storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
       storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
     }
+    i += extra; dst += extra*resInc;
+    if (size != 32) break;
   }
 }

View File

@@ -1154,6 +1154,7 @@ template<> EIGEN_STRONG_INLINE Packet8bf por<Packet8bf>(const Packet8bf& a, cons
 template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
 template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us pxor<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_xor(a, b); }
 template<> EIGEN_STRONG_INLINE Packet8bf pxor<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
   return pxor<Packet8us>(a, b);
 }
@@ -1884,7 +1885,8 @@ template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, con
 }
 
 template<> EIGEN_STRONG_INLINE Packet8bf pnegate<Packet8bf>(const Packet8bf& a) {
-  BF16_TO_F32_UNARY_OP_WRAPPER(pnegate<Packet4f>, a);
+  EIGEN_DECLARE_CONST_FAST_Packet8us(neg_mask,0x8000);
+  return pxor<Packet8us>(p8us_neg_mask, a);
 }
 
 template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
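The new pnegate works because bfloat16 keeps float32's layout with the sign in bit 15 of each lane, so negation is a per-lane XOR with 0x8000; the old BF16_TO_F32_UNARY_OP_WRAPPER path round-tripped through float32 instead. A scalar check of the identity, with hypothetical helper names:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // Flip the sign bit: exactly what pxor with the 0x8000 mask does per lane.
    uint16_t bf16_negate(uint16_t b) { return b ^ 0x8000u; }

    float bf16_to_f32(uint16_t b) {
      uint32_t bits = uint32_t(b) << 16;  // bf16 is the top half of an f32
      float f;
      std::memcpy(&f, &bits, sizeof f);
      return f;
    }

    int main() {
      uint16_t one = 0x3F80;  // bf16 encoding of 1.0f
      assert(bf16_to_f32(bf16_negate(one)) == -1.0f);
      return 0;
    }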