Fix alias violation in BFloat16

Accessing an object through a reinterpret_cast pointer to an unrelated type
violates the strict-aliasing rule; this is undefined behavior and leads to
misoptimizations on some platforms.
Use the safer (and faster) version via bit_cast instead.
This commit is contained in:
Alexander Grund 2021-09-20 10:37:50 +02:00
parent 4d622be118
commit b5eaa42695
No known key found for this signature in database
GPG Key ID: E92C451FC21EF13F

View File

@@ -253,12 +253,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const
 output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
 return output;
 }
-const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
-#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-output.value = p[0];
-#else
-output.value = p[1];
-#endif
+output.value = static_cast<numext::uint16_t>(numext::bit_cast<numext::uint32_t>(v) >> 16);
 return output;
 }
@@ -464,14 +459,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true
 }
 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
-float result = 0;
-unsigned short* q = reinterpret_cast<unsigned short*>(&result);
-#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-q[0] = h.value;
-#else
-q[1] = h.value;
-#endif
-return result;
+return numext::bit_cast<float>(static_cast<numext::uint32_t>(h.value) << 16);
 }
 // --- standard functions ---