Convert bit calculation to constexpr, avoid casts.

This commit is contained in:
Tobias Schlüter 2022-03-13 22:27:06 +09:00
parent baf9a985ec
commit cb1e8228e9

View File

@ -28,11 +28,11 @@ template<> struct make_integer<double> { typedef numext::int64_t type; };
template<> struct make_integer<half> { typedef numext::int16_t type; };
template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1};
static constexpr int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
}
@ -42,42 +42,41 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
enum {
static constexpr int
TotalBits = sizeof(Scalar) * CHAR_BIT,
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
};
ExponentBits = TotalBits - MantissaBits - 1;
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
~(((ScalarUI(1) << ExponentBits) - ScalarUI(1)) << MantissaBits); // ~0x7f800000
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
const Packet half = pset1<Packet>(Scalar(0.5));
const Packet zero = pzero(a);
const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
// To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(MantissaBits + 1); // 24
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
// Determine exponent offset: -126 if normal, -126-24 if denormal
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(ExponentBits-1)) - ScalarUI(2)); // -126
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
// Determine exponent and mantissa from normalized_a.
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
// (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << ExponentBits) - ScalarUI(1)); // 255
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
return m;
}
@ -110,25 +109,24 @@ Packet pldexp_generic(const Packet& a, const Packet& exponent) {
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<PacketI>::type ScalarI;
enum {
static constexpr int
TotalBits = sizeof(Scalar) * CHAR_BIT,
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
};
ExponentBits = TotalBits - MantissaBits - 1;
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<ExponentBits) + ScalarI(MantissaBits - 1))); // 278
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1)); // 127
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
Packet c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^b
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
b = psub(psub(psub(e, b), b), b); // e - 3b
c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^(e-3*b)
out = pmul(out, c);
return out;
}
// Explicitly multiplies
// Explicitly multiplies
// a * (2^e)
// clamping e to the range
// [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()]
@ -142,20 +140,19 @@ struct pldexp_fast_impl {
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<PacketI>::type ScalarI;
enum {
static constexpr int
TotalBits = sizeof(Scalar) * CHAR_BIT,
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
};
ExponentBits = TotalBits - MantissaBits - 1;
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet run(const Packet& a, const Packet& exponent) {
const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127
const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255
const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1))); // 127
const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<ExponentBits) - ScalarI(1))); // 255
// restrict biased exponent between 0 and 255 for float.
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
// return a * (2^e)
return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e)));
return pmul(a, preinterpret<Packet>(plogical_shift_left<MantissaBits>(e)));
}
};