mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-21 09:09:36 +08:00
Eliminate type-punning UB in Eigen::half.
This commit is contained in:
parent
420d891de7
commit
6b4881ad48
@ -522,11 +522,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
union float32_bits {
|
|
||||||
unsigned int u;
|
|
||||||
float f;
|
|
||||||
};
|
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
|
||||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
||||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||||
@ -549,46 +544,45 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
|
|||||||
return h;
|
return h;
|
||||||
|
|
||||||
#else
|
#else
|
||||||
float32_bits f;
|
uint32_t f_bits = Eigen::numext::bit_cast<uint32_t>(ff);
|
||||||
f.f = ff;
|
const uint32_t f32infty_bits = {255 << 23};
|
||||||
|
const uint32_t f16max_bits = {(127 + 16) << 23};
|
||||||
const float32_bits f32infty = {255 << 23};
|
const uint32_t denorm_magic_bits = {((127 - 15) + (23 - 10) + 1) << 23};
|
||||||
const float32_bits f16max = {(127 + 16) << 23};
|
const uint32_t sign_mask = 0x80000000u;
|
||||||
const float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
|
|
||||||
unsigned int sign_mask = 0x80000000u;
|
|
||||||
__half_raw o;
|
__half_raw o;
|
||||||
o.x = static_cast<numext::uint16_t>(0x0u);
|
o.x = static_cast<uint16_t>(0x0u);
|
||||||
|
|
||||||
unsigned int sign = f.u & sign_mask;
|
const uint32_t sign = f_bits & sign_mask;
|
||||||
f.u ^= sign;
|
f_bits ^= sign;
|
||||||
|
|
||||||
// NOTE all the integer compares in this function can be safely
|
// NOTE all the integer compares in this function can be safely
|
||||||
// compiled into signed compares since all operands are below
|
// compiled into signed compares since all operands are below
|
||||||
// 0x80000000. Important if you want fast straight SSE2 code
|
// 0x80000000. Important if you want fast straight SSE2 code
|
||||||
// (since there's no unsigned PCMPGTD).
|
// (since there's no unsigned PCMPGTD).
|
||||||
|
|
||||||
if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
|
if (f_bits >= f16max_bits) { // result is Inf or NaN (all exponent bits set)
|
||||||
o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
|
o.x = (f_bits > f32infty_bits) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
|
||||||
} else { // (De)normalized number or zero
|
} else { // (De)normalized number or zero
|
||||||
if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
|
if (f_bits < (113 << 23)) { // resulting FP16 is subnormal or zero
|
||||||
// use a magic value to align our 10 mantissa bits at the bottom of
|
// use a magic value to align our 10 mantissa bits at the bottom of
|
||||||
// the float. as long as FP addition is round-to-nearest-even this
|
// the float. as long as FP addition is round-to-nearest-even this
|
||||||
// just works.
|
// just works.
|
||||||
f.f += denorm_magic.f;
|
f_bits = Eigen::numext::bit_cast<uint32_t>(Eigen::numext::bit_cast<float>(f_bits) +
|
||||||
|
Eigen::numext::bit_cast<float>(denorm_magic_bits));
|
||||||
|
|
||||||
// and one integer subtract of the bias later, we have our final float!
|
// and one integer subtract of the bias later, we have our final float!
|
||||||
o.x = static_cast<numext::uint16_t>(f.u - denorm_magic.u);
|
o.x = static_cast<numext::uint16_t>(f_bits - denorm_magic_bits);
|
||||||
} else {
|
} else {
|
||||||
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
|
const uint32_t mant_odd = (f_bits >> 13) & 1; // resulting mantissa is odd
|
||||||
|
|
||||||
// update exponent, rounding bias part 1
|
// update exponent, rounding bias part 1
|
||||||
// Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
|
// Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
|
||||||
// without arithmetic overflow.
|
// without arithmetic overflow.
|
||||||
f.u += 0xc8000fffU;
|
f_bits += 0xc8000fffU;
|
||||||
// rounding bias part 2
|
// rounding bias part 2
|
||||||
f.u += mant_odd;
|
f_bits += mant_odd;
|
||||||
// take the bits!
|
// take the bits!
|
||||||
o.x = static_cast<numext::uint16_t>(f.u >> 13);
|
o.x = static_cast<numext::uint16_t>(f_bits >> 13);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -611,24 +605,23 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
|
|||||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||||
return static_cast<float>(h.x);
|
return static_cast<float>(h.x);
|
||||||
#else
|
#else
|
||||||
const float32_bits magic = {113 << 23};
|
const float magic = Eigen::numext::bit_cast<float>(static_cast<uint32_t>(113 << 23));
|
||||||
const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
|
const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
|
||||||
float32_bits o;
|
uint32_t o_bits = (h.x & 0x7fff) << 13; // exponent/mantissa bits
|
||||||
|
const uint32_t exp = shifted_exp & o_bits; // just the exponent
|
||||||
o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits
|
o_bits += (127 - 15) << 23; // exponent adjust
|
||||||
unsigned int exp = shifted_exp & o.u; // just the exponent
|
|
||||||
o.u += (127 - 15) << 23; // exponent adjust
|
|
||||||
|
|
||||||
// handle exponent special cases
|
// handle exponent special cases
|
||||||
if (exp == shifted_exp) { // Inf/NaN?
|
if (exp == shifted_exp) { // Inf/NaN?
|
||||||
o.u += (128 - 16) << 23; // extra exp adjust
|
o_bits += (128 - 16) << 23; // extra exp adjust
|
||||||
} else if (exp == 0) { // Zero/Denormal?
|
} else if (exp == 0) { // Zero/Denormal?
|
||||||
o.u += 1 << 23; // extra exp adjust
|
o_bits += 1 << 23; // extra exp adjust
|
||||||
o.f -= magic.f; // renormalize
|
// renormalize
|
||||||
|
o_bits = Eigen::numext::bit_cast<uint32_t>(Eigen::numext::bit_cast<float>(o_bits) - magic);
|
||||||
}
|
}
|
||||||
|
|
||||||
o.u |= (h.x & 0x8000) << 16; // sign bit
|
o_bits |= (h.x & 0x8000) << 16; // sign bit
|
||||||
return o.f;
|
return Eigen::numext::bit_cast<float>(o_bits);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user