mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-23 14:53:13 +08:00
Unroll F32 to BF16 loop - 1.8X faster conversions for LLVM. Use vector pairs for GCC.
This commit is contained in:
parent
874f5947f4
commit
6418ac0285
@ -2839,13 +2839,13 @@ EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask)
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<bool lhsExtraRows, bool odd, Index size>
|
template<bool lhsExtraRows, bool odd, Index size>
|
||||||
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index col, Index rows, const bfloat16* src, Index extra_rows)
|
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index rows, const bfloat16* src, Index extra_rows)
|
||||||
{
|
{
|
||||||
Packet4f dup[4*4];
|
Packet4f dup[4*4];
|
||||||
Packet8bf data[4];
|
Packet8bf data[4];
|
||||||
|
|
||||||
for (Index i = 0; i < size; i++) {
|
for (Index i = 0; i < size; i++) {
|
||||||
data[i] = ploadu<Packet8bf>(src + col + rows*i);
|
data[i] = ploadu<Packet8bf>(src + rows*i);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (Index i = 0, j = 0; i < size; i++, j += 4) {
|
for (Index i = 0, j = 0; i < size; i++, j += 4) {
|
||||||
@ -2876,15 +2876,16 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index
|
|||||||
template<bool lhsExtraRows>
|
template<bool lhsExtraRows>
|
||||||
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float *result, Index cols, Index rows, const bfloat16* src, Index delta, Index extra_rows)
|
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float *result, Index cols, Index rows, const bfloat16* src, Index delta, Index extra_rows)
|
||||||
{
|
{
|
||||||
Index col2 = 0, col = 0;
|
Index col = 0;
|
||||||
for(; col + 4*2 <= cols; col += 4*2, col2 += 4*rows, result += 4*4*4) {
|
src += delta*2;
|
||||||
convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,4>(result, col2 + delta*2, rows, src, extra_rows);
|
for(; col + 4*2 <= cols; col += 4*2, result += 4*4*4, src += 4*rows) {
|
||||||
|
convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,4>(result, rows, src, extra_rows);
|
||||||
}
|
}
|
||||||
for(; col + 2 <= cols; col += 2, col2 += rows, result += 4*4) {
|
for(; col + 2 <= cols; col += 2, result += 4*4, src += rows) {
|
||||||
convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,1>(result, col2 + delta*2, rows, src, extra_rows);
|
convertArrayPointerBF16toF32DupOne<lhsExtraRows,false,1>(result, rows, src, extra_rows);
|
||||||
}
|
}
|
||||||
if (cols & 1) {
|
if (cols & 1) {
|
||||||
convertArrayPointerBF16toF32DupOne<lhsExtraRows,true,1>(result, col2 + delta, rows, src, extra_rows);
|
convertArrayPointerBF16toF32DupOne<lhsExtraRows,true,1>(result, rows, src - delta, extra_rows);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2892,7 +2893,7 @@ template<const Index size, bool non_unit_stride>
|
|||||||
EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
|
EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
|
||||||
{
|
{
|
||||||
constexpr Index extra = ((size < 4) ? 4 : size);
|
constexpr Index extra = ((size < 4) ? 4 : size);
|
||||||
for(; i + size <= rows; i += extra, src += extra*resInc){
|
while (i + size <= rows) {
|
||||||
PacketBlock<Packet8bf,(size+7)/8> r32;
|
PacketBlock<Packet8bf,(size+7)/8> r32;
|
||||||
r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
|
r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
|
||||||
if (size >= 16) {
|
if (size >= 16) {
|
||||||
@ -2903,6 +2904,8 @@ EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index
|
|||||||
r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
|
r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
|
||||||
}
|
}
|
||||||
storeConvertBlockBF16<size>(result + i, r32, rows & 3);
|
storeConvertBlockBF16<size>(result + i, r32, rows & 3);
|
||||||
|
i += extra; src += extra*resInc;
|
||||||
|
if (size != 32) break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3131,7 +3134,7 @@ template<const Index size, typename DataMapper>
|
|||||||
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
|
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
|
||||||
{
|
{
|
||||||
constexpr Index extra = ((size < 4) ? 4 : size);
|
constexpr Index extra = ((size < 4) ? 4 : size);
|
||||||
for(; i + size <= rows; i += extra){
|
while (i + size <= rows) {
|
||||||
PacketBlock<Packet8bf,(size+7)/8> r32;
|
PacketBlock<Packet8bf,(size+7)/8> r32;
|
||||||
r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
|
r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
|
||||||
if (size >= 16) {
|
if (size >= 16) {
|
||||||
@ -3142,6 +3145,8 @@ EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, c
|
|||||||
r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
|
r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
|
||||||
}
|
}
|
||||||
storeConvertBlockBF16<size>(result + i, r32, rows & 3);
|
storeConvertBlockBF16<size>(result + i, r32, rows & 3);
|
||||||
|
i += extra;
|
||||||
|
if (size != 32) break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3171,18 +3176,18 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float *result, Index col, I
|
|||||||
const DataMapper res2 = res.getSubMapper(0, col);
|
const DataMapper res2 = res.getSubMapper(0, col);
|
||||||
Index row;
|
Index row;
|
||||||
float *result2 = result + col*rows;
|
float *result2 = result + col*rows;
|
||||||
for(row = 0; row + 8 <= rows; row += 8){
|
for(row = 0; row + 8 <= rows; row += 8, result2 += 8){
|
||||||
// get and save block
|
// get and save block
|
||||||
PacketBlock<Packet8bf,size> block;
|
PacketBlock<Packet8bf,size> block;
|
||||||
for(Index j = 0; j < size; j++){
|
for(Index j = 0; j < size; j++){
|
||||||
block.packet[j] = convertF32toBF16VSX(result2 + j*rows + row);
|
block.packet[j] = convertF32toBF16VSX(result2 + j*rows);
|
||||||
}
|
}
|
||||||
res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
|
res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
|
||||||
}
|
}
|
||||||
// extra rows
|
// extra rows
|
||||||
if(row < rows){
|
if(row < rows){
|
||||||
for(Index j = 0; j < size; j++){
|
for(Index j = 0; j < size; j++){
|
||||||
Packet8bf fp16 = convertF32toBF16VSX(result2 + j*rows + row);
|
Packet8bf fp16 = convertF32toBF16VSX(result2 + j*rows);
|
||||||
res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
|
res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3196,9 +3201,16 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float *result, Index cols, Ind
|
|||||||
convertArrayF32toBF16ColVSX<DataMapper,4>(result, col, rows, res);
|
convertArrayF32toBF16ColVSX<DataMapper,4>(result, col, rows, res);
|
||||||
}
|
}
|
||||||
// extra cols
|
// extra cols
|
||||||
while(col < cols){
|
switch (cols - col) {
|
||||||
|
case 1:
|
||||||
convertArrayF32toBF16ColVSX<DataMapper,1>(result, col, rows, res);
|
convertArrayF32toBF16ColVSX<DataMapper,1>(result, col, rows, res);
|
||||||
col++;
|
break;
|
||||||
|
case 2:
|
||||||
|
convertArrayF32toBF16ColVSX<DataMapper,2>(result, col, rows, res);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
convertArrayF32toBF16ColVSX<DataMapper,3>(result, col, rows, res);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -215,15 +215,10 @@ EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Pac
|
|||||||
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
|
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
|
||||||
{
|
{
|
||||||
Packet16uc fp16[2];
|
Packet16uc fp16[2];
|
||||||
#if EIGEN_COMP_LLVM
|
|
||||||
__vector_pair fp16_vp = *reinterpret_cast<__vector_pair *>(const_cast<float *>(res));
|
__vector_pair fp16_vp = *reinterpret_cast<__vector_pair *>(const_cast<float *>(res));
|
||||||
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
|
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
|
||||||
fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
|
fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
|
||||||
fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
|
fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
|
||||||
#else
|
|
||||||
fp16[0] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
|
|
||||||
fp16[1] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4)));
|
|
||||||
#endif
|
|
||||||
return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
|
return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,18 +228,20 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float *result, Index col, Inde
|
|||||||
const DataMapper res2 = res.getSubMapper(0, col);
|
const DataMapper res2 = res.getSubMapper(0, col);
|
||||||
Index row;
|
Index row;
|
||||||
float *result2 = result + col*rows;
|
float *result2 = result + col*rows;
|
||||||
for(row = 0; row + 8 <= rows; row += 8){
|
for(row = 0; row + 8 <= rows; row += 8, result2 += 8){
|
||||||
// get and save block
|
// get and save block
|
||||||
PacketBlock<Packet8bf,size> block;
|
PacketBlock<Packet8bf,size> block;
|
||||||
|
BFLOAT16_UNROLL
|
||||||
for(Index j = 0; j < size; j++){
|
for(Index j = 0; j < size; j++){
|
||||||
block.packet[j] = convertF32toBF16(result2 + j*rows + row);
|
block.packet[j] = convertF32toBF16(result2 + j*rows);
|
||||||
}
|
}
|
||||||
res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
|
res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
|
||||||
}
|
}
|
||||||
// extra rows
|
// extra rows
|
||||||
if(row < rows){
|
if(row < rows){
|
||||||
|
BFLOAT16_UNROLL
|
||||||
for(Index j = 0; j < size; j++){
|
for(Index j = 0; j < size; j++){
|
||||||
Packet8bf fp16 = convertF32toBF16(result2 + j*rows + row);
|
Packet8bf fp16 = convertF32toBF16(result2 + j*rows);
|
||||||
res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
|
res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -254,7 +251,7 @@ template<const Index size, bool non_unit_stride = false>
|
|||||||
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
|
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
|
||||||
{
|
{
|
||||||
constexpr Index extra = ((size < 8) ? 8 : size);
|
constexpr Index extra = ((size < 8) ? 8 : size);
|
||||||
for(; i + size <= rows; i += extra, dst += extra*resInc){
|
while (i + size <= rows){
|
||||||
PacketBlock<Packet8bf,(size+7)/8> r32;
|
PacketBlock<Packet8bf,(size+7)/8> r32;
|
||||||
r32.packet[0] = convertF32toBF16(result + i + 0);
|
r32.packet[0] = convertF32toBF16(result + i + 0);
|
||||||
if (size >= 16) {
|
if (size >= 16) {
|
||||||
@ -272,6 +269,8 @@ EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index
|
|||||||
storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
|
storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
|
||||||
storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
|
storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
|
||||||
}
|
}
|
||||||
|
i += extra; dst += extra*resInc;
|
||||||
|
if (size != 32) break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -293,9 +292,16 @@ EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index
|
|||||||
convertArrayF32toBF16Col<DataMapper,4>(result, col, rows, res);
|
convertArrayF32toBF16Col<DataMapper,4>(result, col, rows, res);
|
||||||
}
|
}
|
||||||
// extra cols
|
// extra cols
|
||||||
while(col < cols){
|
switch (cols - col) {
|
||||||
|
case 1:
|
||||||
convertArrayF32toBF16Col<DataMapper,1>(result, col, rows, res);
|
convertArrayF32toBF16Col<DataMapper,1>(result, col, rows, res);
|
||||||
col++;
|
break;
|
||||||
|
case 2:
|
||||||
|
convertArrayF32toBF16Col<DataMapper,2>(result, col, rows, res);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
convertArrayF32toBF16Col<DataMapper,3>(result, col, rows, res);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -657,7 +657,7 @@ template<const Index size, bool inc = false>
|
|||||||
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
|
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
|
||||||
{
|
{
|
||||||
constexpr Index extra = ((size < 8) ? 8 : size);
|
constexpr Index extra = ((size < 8) ? 8 : size);
|
||||||
for(; i + size <= rows; i += extra, dst += extra*resInc){
|
while (i + size <= rows) {
|
||||||
PacketBlock<Packet8bf,(size+7)/8> r32;
|
PacketBlock<Packet8bf,(size+7)/8> r32;
|
||||||
r32.packet[0] = convertF32toBF16VSX(result + i + 0);
|
r32.packet[0] = convertF32toBF16VSX(result + i + 0);
|
||||||
if (size >= 16) {
|
if (size >= 16) {
|
||||||
@ -675,6 +675,8 @@ EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Ind
|
|||||||
storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
|
storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
|
||||||
storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
|
storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
|
||||||
}
|
}
|
||||||
|
i += extra; dst += extra*resInc;
|
||||||
|
if (size != 32) break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1154,6 +1154,7 @@ template<> EIGEN_STRONG_INLINE Packet8bf por<Packet8bf>(const Packet8bf& a, cons
|
|||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet4f pxor<Packet4f>(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
|
template<> EIGEN_STRONG_INLINE Packet4i pxor<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); }
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8us pxor<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_xor(a, b); }
|
||||||
template<> EIGEN_STRONG_INLINE Packet8bf pxor<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
template<> EIGEN_STRONG_INLINE Packet8bf pxor<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
return pxor<Packet8us>(a, b);
|
return pxor<Packet8us>(a, b);
|
||||||
}
|
}
|
||||||
@ -1884,7 +1885,8 @@ template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, con
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8bf pnegate<Packet8bf>(const Packet8bf& a) {
|
template<> EIGEN_STRONG_INLINE Packet8bf pnegate<Packet8bf>(const Packet8bf& a) {
|
||||||
BF16_TO_F32_UNARY_OP_WRAPPER(pnegate<Packet4f>, a);
|
EIGEN_DECLARE_CONST_FAST_Packet8us(neg_mask,0x8000);
|
||||||
|
return pxor<Packet8us>(p8us_neg_mask, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user