Eliminate type-punning UB in Eigen::half.

This commit is contained in:
Antonio Sánchez 2025-02-12 21:12:33 +00:00
parent 420d891de7
commit 6b4881ad48

View File

@ -522,11 +522,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const
#endif #endif
} }
union float32_bits {
unsigned int u;
float f;
};
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
@ -549,46 +544,45 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
return h; return h;
#else #else
float32_bits f; uint32_t f_bits = Eigen::numext::bit_cast<uint32_t>(ff);
f.f = ff; const uint32_t f32infty_bits = {255 << 23};
const uint32_t f16max_bits = {(127 + 16) << 23};
const float32_bits f32infty = {255 << 23}; const uint32_t denorm_magic_bits = {((127 - 15) + (23 - 10) + 1) << 23};
const float32_bits f16max = {(127 + 16) << 23}; const uint32_t sign_mask = 0x80000000u;
const float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
unsigned int sign_mask = 0x80000000u;
__half_raw o; __half_raw o;
o.x = static_cast<numext::uint16_t>(0x0u); o.x = static_cast<uint16_t>(0x0u);
unsigned int sign = f.u & sign_mask; const uint32_t sign = f_bits & sign_mask;
f.u ^= sign; f_bits ^= sign;
// NOTE all the integer compares in this function can be safely // NOTE all the integer compares in this function can be safely
// compiled into signed compares since all operands are below // compiled into signed compares since all operands are below
// 0x80000000. Important if you want fast straight SSE2 code // 0x80000000. Important if you want fast straight SSE2 code
// (since there's no unsigned PCMPGTD). // (since there's no unsigned PCMPGTD).
if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) if (f_bits >= f16max_bits) { // result is Inf or NaN (all exponent bits set)
o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf o.x = (f_bits > f32infty_bits) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
} else { // (De)normalized number or zero } else { // (De)normalized number or zero
if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero if (f_bits < (113 << 23)) { // resulting FP16 is subnormal or zero
// use a magic value to align our 10 mantissa bits at the bottom of // use a magic value to align our 10 mantissa bits at the bottom of
// the float. as long as FP addition is round-to-nearest-even this // the float. as long as FP addition is round-to-nearest-even this
// just works. // just works.
f.f += denorm_magic.f; f_bits = Eigen::numext::bit_cast<uint32_t>(Eigen::numext::bit_cast<float>(f_bits) +
Eigen::numext::bit_cast<float>(denorm_magic_bits));
// and one integer subtract of the bias later, we have our final float! // and one integer subtract of the bias later, we have our final float!
o.x = static_cast<numext::uint16_t>(f.u - denorm_magic.u); o.x = static_cast<numext::uint16_t>(f_bits - denorm_magic_bits);
} else { } else {
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd const uint32_t mant_odd = (f_bits >> 13) & 1; // resulting mantissa is odd
// update exponent, rounding bias part 1 // update exponent, rounding bias part 1
// Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
// without arithmetic overflow. // without arithmetic overflow.
f.u += 0xc8000fffU; f_bits += 0xc8000fffU;
// rounding bias part 2 // rounding bias part 2
f.u += mant_odd; f_bits += mant_odd;
// take the bits! // take the bits!
o.x = static_cast<numext::uint16_t>(f.u >> 13); o.x = static_cast<numext::uint16_t>(f_bits >> 13);
} }
} }
@ -611,24 +605,23 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) #elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return static_cast<float>(h.x); return static_cast<float>(h.x);
#else #else
const float32_bits magic = {113 << 23}; const float magic = Eigen::numext::bit_cast<float>(static_cast<uint32_t>(113 << 23));
const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
float32_bits o; uint32_t o_bits = (h.x & 0x7fff) << 13; // exponent/mantissa bits
const uint32_t exp = shifted_exp & o_bits; // just the exponent
o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits o_bits += (127 - 15) << 23; // exponent adjust
unsigned int exp = shifted_exp & o.u; // just the exponent
o.u += (127 - 15) << 23; // exponent adjust
// handle exponent special cases // handle exponent special cases
if (exp == shifted_exp) { // Inf/NaN? if (exp == shifted_exp) { // Inf/NaN?
o.u += (128 - 16) << 23; // extra exp adjust o_bits += (128 - 16) << 23; // extra exp adjust
} else if (exp == 0) { // Zero/Denormal? } else if (exp == 0) { // Zero/Denormal?
o.u += 1 << 23; // extra exp adjust o_bits += 1 << 23; // extra exp adjust
o.f -= magic.f; // renormalize // renormalize
o_bits = Eigen::numext::bit_cast<uint32_t>(Eigen::numext::bit_cast<float>(o_bits) - magic);
} }
o.u |= (h.x & 0x8000) << 16; // sign bit o_bits |= (h.x & 0x8000) << 16; // sign bit
return o.f; return Eigen::numext::bit_cast<float>(o_bits);
#endif #endif
} }