diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index a8cb228c1..95697f3cf 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -522,11 +522,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const #endif } -union float32_bits { - unsigned int u; - float f; -}; - EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) { #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) @@ -549,46 +544,45 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) { return h; #else - float32_bits f; - f.f = ff; - - const float32_bits f32infty = {255 << 23}; - const float32_bits f16max = {(127 + 16) << 23}; - const float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; - unsigned int sign_mask = 0x80000000u; + uint32_t f_bits = Eigen::numext::bit_cast(ff); + const uint32_t f32infty_bits = {255 << 23}; + const uint32_t f16max_bits = {(127 + 16) << 23}; + const uint32_t denorm_magic_bits = {((127 - 15) + (23 - 10) + 1) << 23}; + const uint32_t sign_mask = 0x80000000u; __half_raw o; - o.x = static_cast(0x0u); + o.x = static_cast(0x0u); - unsigned int sign = f.u & sign_mask; - f.u ^= sign; + const uint32_t sign = f_bits & sign_mask; + f_bits ^= sign; // NOTE all the integer compares in this function can be safely // compiled into signed compares since all operands are below // 0x80000000. Important if you want fast straight SSE2 code // (since there's no unsigned PCMPGTD). - if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) - o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf - } else { // (De)normalized number or zero - if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero + if (f_bits >= f16max_bits) { // result is Inf or NaN (all exponent bits set) + o.x = (f_bits > f32infty_bits) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f_bits < (113 << 23)) { // resulting FP16 is subnormal or zero // use a magic value to align our 10 mantissa bits at the bottom of // the float. as long as FP addition is round-to-nearest-even this // just works. - f.f += denorm_magic.f; + f_bits = Eigen::numext::bit_cast(Eigen::numext::bit_cast(f_bits) + + Eigen::numext::bit_cast(denorm_magic_bits)); // and one integer subtract of the bias later, we have our final float! - o.x = static_cast(f.u - denorm_magic.u); + o.x = static_cast(f_bits - denorm_magic_bits); } else { - unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + const uint32_t mant_odd = (f_bits >> 13) & 1; // resulting mantissa is odd // update exponent, rounding bias part 1 // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but // without arithmetic overflow. - f.u += 0xc8000fffU; + f_bits += 0xc8000fffU; // rounding bias part 2 - f.u += mant_odd; + f_bits += mant_odd; // take the bits! - o.x = static_cast(f.u >> 13); + o.x = static_cast(f_bits >> 13); } } @@ -611,24 +605,23 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) { #elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) return static_cast(h.x); #else - const float32_bits magic = {113 << 23}; - const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift - float32_bits o; - - o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits - unsigned int exp = shifted_exp & o.u; // just the exponent - o.u += (127 - 15) << 23; // exponent adjust + const float magic = Eigen::numext::bit_cast(static_cast(113 << 23)); + const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + uint32_t o_bits = (h.x & 0x7fff) << 13; // exponent/mantissa bits + const uint32_t exp = shifted_exp & o_bits; // just the exponent + o_bits += (127 - 15) << 23; // exponent adjust // handle exponent special cases - if (exp == shifted_exp) { // Inf/NaN? - o.u += (128 - 16) << 23; // extra exp adjust - } else if (exp == 0) { // Zero/Denormal? - o.u += 1 << 23; // extra exp adjust - o.f -= magic.f; // renormalize + if (exp == shifted_exp) { // Inf/NaN? + o_bits += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o_bits += 1 << 23; // extra exp adjust + // renormalize + o_bits = Eigen::numext::bit_cast(Eigen::numext::bit_cast(o_bits) - magic); } - o.u |= (h.x & 0x8000) << 16; // sign bit - return o.f; + o_bits |= (h.x & 0x8000) << 16; // sign bit + return Eigen::numext::bit_cast(o_bits); #endif }