From 6418ac02859547f12856be3952a1ba6cc41104f7 Mon Sep 17 00:00:00 2001
From: Chip Kerchner
Date: Mon, 1 May 2023 16:54:16 +0000
Subject: [PATCH] Unroll F32 to BF16 loop - 1.8X faster conversions for LLVM.
 Use vector pairs for GCC.

---
 Eigen/src/Core/arch/AltiVec/MatrixProduct.h  | 42 ++++++++++++-------
 .../arch/AltiVec/MatrixProductMMAbfloat16.h  | 28 ++++++++-----
 .../Core/arch/AltiVec/MatrixVectorProduct.h  |  4 +-
 Eigen/src/Core/arch/AltiVec/PacketMath.h     |  4 +-
 4 files changed, 50 insertions(+), 28 deletions(-)

diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index 5545234f9..e86cc5b4f 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -2839,13 +2839,13 @@ EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask)
 }
 
 template<const Index size>
-EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index col, Index rows, const bfloat16* src, Index extra_rows)
+EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index rows, const bfloat16* src, Index extra_rows)
 {
   Packet4f dup[4*4];
   Packet8bf data[4];
 
   for (Index i = 0; i < size; i++) {
-    data[i] = ploadu<Packet8bf>(src + col + rows*i);
+    data[i] = ploadu<Packet8bf>(src + rows*i);
   }
 
   for (Index i = 0, j = 0; i < size; i++, j += 4) {
@@ -2876,15 +2876,16 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index
 template
 EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float *result, Index cols, Index rows, const bfloat16* src, Index delta, Index extra_rows)
 {
-  Index col2 = 0, col = 0;
-  for(; col + 4*2 <= cols; col += 4*2, col2 += 4*rows, result += 4*4*4) {
-    convertArrayPointerBF16toF32DupOne<4>(result, col2 + delta*2, rows, src, extra_rows);
+  Index col = 0;
+  src += delta*2;
+  for(; col + 4*2 <= cols; col += 4*2, result += 4*4*4, src += 4*rows) {
+    convertArrayPointerBF16toF32DupOne<4>(result, rows, src, extra_rows);
   }
-  for(; col + 2 <= cols; col += 2, col2 += rows, result += 4*4) {
-    convertArrayPointerBF16toF32DupOne<1>(result, col2 + delta*2, rows, src, extra_rows);
+  for(; col + 2 <= cols; col += 2, result += 4*4, src += rows) {
+    convertArrayPointerBF16toF32DupOne<1>(result, rows, src, extra_rows);
   }
   if (cols & 1) {
-    convertArrayPointerBF16toF32DupOne<1>(result, col2 + delta, rows, src, extra_rows);
+    convertArrayPointerBF16toF32DupOne<1>(result, rows, src - delta, extra_rows);
   }
 }
 
@@ -2892,7 +2893,7 @@ template<const Index size, bool non_unit_stride>
 EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
 {
   constexpr Index extra = ((size < 4) ? 4 : size);
-  for(; i + size <= rows; i += extra, src += extra*resInc){
+  while (i + size <= rows) {
     PacketBlock<Packet8bf,(size+7)/8> r32;
     r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
     if (size >= 16) {
@@ -2903,6 +2904,8 @@ EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index
       r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
     }
     storeConvertBlockBF16<size>(result + i, r32, rows & 3);
+    i += extra; src += extra*resInc;
+    if (size != 32) break;
   }
 }
 
@@ -3131,7 +3134,7 @@ template<const Index size, typename DataMapper>
 EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
 {
   constexpr Index extra = ((size < 4) ? 4 : size);
-  for(; i + size <= rows; i += extra){
+  while (i + size <= rows) {
     PacketBlock<Packet8bf,(size+7)/8> r32;
     r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
     if (size >= 16) {
@@ -3142,6 +3145,8 @@ EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, c
       r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
     }
     storeConvertBlockBF16<size>(result + i, r32, rows & 3);
+    i += extra;
+    if (size != 32) break;
   }
 }
 
@@ -3171,18 +3176,18 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float *result, Index col, I
   const DataMapper res2 = res.getSubMapper(0, col);
   Index row;
   float *result2 = result + col*rows;
-  for(row = 0; row + 8 <= rows; row += 8){
+  for(row = 0; row + 8 <= rows; row += 8, result2 += 8){
     // get and save block
     PacketBlock<Packet8bf,size> block;
     for(Index j = 0; j < size; j++){
-      block.packet[j] = convertF32toBF16VSX(result2 + j*rows + row);
+      block.packet[j] = convertF32toBF16VSX(result2 + j*rows);
     }
     res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
   }
   // extra rows
   if(row < rows){
     for(Index j = 0; j < size; j++){
-      Packet8bf fp16 = convertF32toBF16VSX(result2 + j*rows + row);
+      Packet8bf fp16 = convertF32toBF16VSX(result2 + j*rows);
       res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
     }
   }
@@ -3196,9 +3201,16 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float *result, Index cols, Ind
     convertArrayF32toBF16ColVSX<DataMapper,4>(result, col, rows, res);
   }
   // extra cols
-  while(col < cols){
+  switch (cols - col) {
+  case 1:
     convertArrayF32toBF16ColVSX<DataMapper,1>(result, col, rows, res);
-    col++;
+    break;
+  case 2:
+    convertArrayF32toBF16ColVSX<DataMapper,2>(result, col, rows, res);
+    break;
+  case 3:
+    convertArrayF32toBF16ColVSX<DataMapper,3>(result, col, rows, res);
+    break;
   }
 }
 
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
index 6bdbb0b56..0944c2d16 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
@@ -215,15 +215,10 @@ EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Pac
 EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
 {
   Packet16uc fp16[2];
-#if EIGEN_COMP_LLVM
   __vector_pair fp16_vp = *reinterpret_cast<__vector_pair *>(const_cast<float *>(res));
   __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
   fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
   fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
-#else
-  fp16[0] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
-  fp16[1] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4)));
-#endif
   return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
 }
 
@@ -233,18 +228,20 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float *result, Index col, Inde
   const DataMapper res2 = res.getSubMapper(0, col);
   Index row;
   float *result2 = result + col*rows;
-  for(row = 0; row + 8 <= rows; row += 8){
+  for(row = 0; row + 8 <= rows; row += 8, result2 += 8){
     // get and save block
     PacketBlock<Packet8bf,size> block;
+    BFLOAT16_UNROLL
     for(Index j = 0; j < size; j++){
-      block.packet[j] = convertF32toBF16(result2 + j*rows + row);
+      block.packet[j] = convertF32toBF16(result2 + j*rows);
     }
     res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
   }
   // extra rows
   if(row < rows){
+    BFLOAT16_UNROLL
     for(Index j = 0; j < size; j++){
-      Packet8bf fp16 = convertF32toBF16(result2 + j*rows + row);
+      Packet8bf fp16 = convertF32toBF16(result2 + j*rows);
       res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
     }
   }
 }
@@ -254,7 +251,7 @@ template<const Index size, bool non_unit_stride = false>
 EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
 {
   constexpr Index extra = ((size < 8) ? 8 : size);
-  for(; i + size <= rows; i += extra, dst += extra*resInc){
+  while (i + size <= rows){
     PacketBlock<Packet8bf,(size+7)/8> r32;
     r32.packet[0] = convertF32toBF16(result + i + 0);
     if (size >= 16) {
@@ -272,6 +269,8 @@ EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index
       storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
       storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
     }
+    i += extra; dst += extra*resInc;
+    if (size != 32) break;
   }
 }
 
@@ -293,9 +292,16 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index
     convertArrayF32toBF16Col<DataMapper,4>(result, col, rows, res);
   }
   // extra cols
-  while(col < cols){
+  switch (cols - col) {
+  case 1:
     convertArrayF32toBF16Col<DataMapper,1>(result, col, rows, res);
-    col++;
+    break;
+  case 2:
+    convertArrayF32toBF16Col<DataMapper,2>(result, col, rows, res);
+    break;
+  case 3:
+    convertArrayF32toBF16Col<DataMapper,3>(result, col, rows, res);
+    break;
   }
 }
 
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
index bb851fe1b..62840a3e9 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
@@ -657,7 +657,7 @@ template<const Index size, bool non_unit_stride = false>
 EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
 {
   constexpr Index extra = ((size < 8) ? 8 : size);
-  for(; i + size <= rows; i += extra, dst += extra*resInc){
+  while (i + size <= rows) {
     PacketBlock<Packet8bf,(size+7)/8> r32;
     r32.packet[0] = convertF32toBF16VSX(result + i + 0);
     if (size >= 16) {
@@ -675,6 +675,8 @@ EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Ind
       storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
       storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
     }
+    i += extra; dst += extra*resInc;
+    if (size != 32) break;
   }
 }
 
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index 5f4ccda5e..b24796e4a 100644
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -1154,6 +1154,7 @@ template<> EIGEN_STRONG_INLINE Packet8bf por<Packet8bf>(const Packet8bf& a, cons
 template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
 template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
+template<> EIGEN_STRONG_INLINE Packet8us pxor<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_xor(a, b); }
 template<> EIGEN_STRONG_INLINE Packet8bf pxor<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
   return pxor<Packet8us>(a, b);
 }
 
@@ -1884,7 +1885,8 @@ template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, con
 }
 
 template<> EIGEN_STRONG_INLINE Packet8bf pnegate<Packet8bf>(const Packet8bf& a) {
-  BF16_TO_F32_UNARY_OP_WRAPPER(pnegate<Packet4f>, a);
+  EIGEN_DECLARE_CONST_FAST_Packet8us(neg_mask,0x8000);
+  return pxor<Packet8us>(p8us_neg_mask, a);
 }
 
 template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
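
--
Notes (reviewer addenda, not part of the patch itself):

* The for -> while + break rewrite. Each conversion loop of the form
  "for(; i + size <= rows; i += extra, ...)" becomes a while loop whose body
  ends in "i += extra; ...; if (size != 32) break;". These functions are
  templates instantiated per size, so every instantiation with size != 32 now
  provably executes its body at most once; the compiler can drop the back edge
  and emit straight-line code, and only the size == 32 instantiation still
  loops. This is the unrolling the subject line credits with the 1.8X LLVM
  speedup, and the break condition implies callers try the widest
  instantiation first and the narrower ones as one-shot tails. A
  self-contained C++ sketch of the shape, using a hypothetical scalar
  stand-in for the packet body (Eigen's real kernel converts whole vectors,
  rounds to nearest-even, and rounds the step up to a full packet for narrow
  sizes; all of that is simplified away here):

    #include <cstdint>
    #include <cstring>

    // Hypothetical scalar stand-in for the vector convert/store body.
    static inline uint16_t f32_to_bf16_trunc(float f) {
      uint32_t u;
      std::memcpy(&u, &f, sizeof u);
      return static_cast<uint16_t>(u >> 16);  // bf16 = top 16 bits of f32
    }

    template <int size>
    inline void convert_run(long& i, long n, const float* in, uint16_t* out) {
      while (i + size <= n) {
        for (int j = 0; j < size; j++) out[i + j] = f32_to_bf16_trunc(in[i + j]);
        i += size;
        if (size != 32) break;  // only the size == 32 instantiation iterates
      }
    }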
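
* The switch on the column remainder. The scalar "while(col < cols)" tails in
  convertArrayF32toBF16 and convertArrayF32toBF16VSX become
  "switch (cols - col)": the main loop advances four columns at a time, so at
  most three columns remain, and dispatching on the exact remainder runs the
  column kernel once with a compile-time width instead of running the width-1
  kernel up to three times. A sketch of the pattern with a hypothetical
  kernel (names are mine, not Eigen's):

    // Width is a compile-time constant, so the inner loop fully unrolls in
    // each instantiation (column-major layout assumed, as in the patch).
    template <int width>
    static void scale_cols(float* data, long col, long rows) {
      for (long r = 0; r < rows; r++)
        for (int c = 0; c < width; c++) data[(col + c)*rows + r] *= 2.0f;
    }

    static void scale_all(float* data, long cols, long rows) {
      long col = 0;
      for (; col + 4 <= cols; col += 4) scale_cols<4>(data, col, rows);
      switch (cols - col) {  // remainder is 0, 1, 2, or 3
        case 1: scale_cols<1>(data, col, rows); break;
        case 2: scale_cols<2>(data, col, rows); break;
        case 3: scale_cols<3>(data, col, rows); break;
      }
    }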
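
* F32 -> BF16 via vector pairs for both compilers. convertF32toBF16 in
  MatrixProductMMAbfloat16.h drops its #if EIGEN_COMP_LLVM guard, so GCC now
  also takes the __vector_pair path: one paired 32-byte load is disassembled
  into two VSRs, each half is converted with xvcvspbf16, and the two word
  vectors are packed down to eight bf16 lanes. A minimal standalone sketch of
  the same sequence, assuming a POWER10 toolchain (e.g. -mcpu=power10); the
  helper name and typedefs are mine, not Eigen's:

    #include <altivec.h>

    typedef __vector unsigned char  vuc_t;
    typedef __vector unsigned int   vui_t;
    typedef __vector unsigned short vus_t;

    // Convert 8 contiguous floats to 8 bfloat16 lanes in one vector register.
    static inline vus_t f32_to_bf16_8(const float *src) {
      vuc_t half[2];
      // One paired 32-byte load, split back into two 16-byte VSRs.
      __vector_pair vp = *reinterpret_cast<__vector_pair *>(const_cast<float *>(src));
      __builtin_vsx_disassemble_pair(reinterpret_cast<void *>(half), &vp);
      // xvcvspbf16 rounds each f32 lane to bf16 (nearest-even).
      half[0] = __builtin_vsx_xvcvspbf16(half[0]);
      half[1] = __builtin_vsx_xvcvspbf16(half[1]);
      // Pack the two 4 x 32-bit results into 8 x 16-bit bf16 lanes.
      return vec_pack(reinterpret_cast<vui_t>(half[0]),
                      reinterpret_cast<vui_t>(half[1]));
    }

  The paired load fetches both halves with a single instruction, which is
  presumably where the GCC path gains over the removed pair of ploadu loads.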
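
* bfloat16 negation without a float round-trip. pnegate<Packet8bf> previously
  widened to f32 through BF16_TO_F32_UNARY_OP_WRAPPER; it now XORs each
  16-bit lane with 0x8000, and the new pxor<Packet8us> overload supports
  that. This is exact because bfloat16 keeps IEEE 754's sign-in-MSB layout,
  so flipping bit 15 negates every value, including zeros, infinities, and
  NaN payloads. A scalar model:

    #include <cstdint>

    // Flip the sign bit of a raw bf16 value, e.g. 0x3F80 (1.0f) -> 0xBF80 (-1.0f).
    static inline uint16_t bf16_negate(uint16_t bits) {
      return bits ^ UINT16_C(0x8000);
    }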