mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-01 08:14:10 +08:00
Convert bit calculation to constexpr, avoid casts.
This commit is contained in:
parent
baf9a985ec
commit
cb1e8228e9
@ -28,11 +28,11 @@ template<> struct make_integer<double> { typedef numext::int64_t type; };
|
||||
template<> struct make_integer<half> { typedef numext::int16_t type; };
|
||||
template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };
|
||||
|
||||
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1};
|
||||
static constexpr int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
|
||||
return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
|
||||
}
|
||||
|
||||
@ -42,42 +42,41 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
|
||||
enum {
|
||||
static constexpr int
|
||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
||||
};
|
||||
ExponentBits = TotalBits - MantissaBits - 1;
|
||||
|
||||
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
|
||||
~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
|
||||
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
|
||||
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
|
||||
~(((ScalarUI(1) << ExponentBits) - ScalarUI(1)) << MantissaBits); // ~0x7f800000
|
||||
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
|
||||
const Packet half = pset1<Packet>(Scalar(0.5));
|
||||
const Packet zero = pzero(a);
|
||||
const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
|
||||
|
||||
|
||||
// To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
|
||||
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
|
||||
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
|
||||
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(MantissaBits + 1); // 24
|
||||
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
|
||||
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
|
||||
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
|
||||
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
|
||||
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
|
||||
|
||||
|
||||
// Determine exponent offset: -126 if normal, -126-24 if denormal
|
||||
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
|
||||
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(ExponentBits-1)) - ScalarUI(2)); // -126
|
||||
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
|
||||
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
|
||||
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
|
||||
|
||||
|
||||
// Determine exponent and mantissa from normalized_a.
|
||||
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
|
||||
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
|
||||
// (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
|
||||
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255
|
||||
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << ExponentBits) - ScalarUI(1)); // 255
|
||||
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
|
||||
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
|
||||
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
|
||||
exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
|
||||
exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
|
||||
return m;
|
||||
}
|
||||
|
||||
@ -110,25 +109,24 @@ Packet pldexp_generic(const Packet& a, const Packet& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
||||
enum {
|
||||
static constexpr int
|
||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
||||
};
|
||||
ExponentBits = TotalBits - MantissaBits - 1;
|
||||
|
||||
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278
|
||||
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
|
||||
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<ExponentBits) + ScalarI(MantissaBits - 1))); // 278
|
||||
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1)); // 127
|
||||
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
|
||||
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
|
||||
Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
|
||||
Packet c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^b
|
||||
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
|
||||
b = psub(psub(psub(e, b), b), b); // e - 3b
|
||||
c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
|
||||
c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^(e-3*b)
|
||||
out = pmul(out, c);
|
||||
return out;
|
||||
}
|
||||
|
||||
// Explicitly multiplies
|
||||
// Explicitly multiplies
|
||||
// a * (2^e)
|
||||
// clamping e to the range
|
||||
// [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()]
|
||||
@ -142,20 +140,19 @@ struct pldexp_fast_impl {
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
||||
enum {
|
||||
static constexpr int
|
||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
||||
};
|
||||
|
||||
ExponentBits = TotalBits - MantissaBits - 1;
|
||||
|
||||
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet run(const Packet& a, const Packet& exponent) {
|
||||
const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127
|
||||
const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255
|
||||
const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1))); // 127
|
||||
const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<ExponentBits) - ScalarI(1))); // 255
|
||||
// restrict biased exponent between 0 and 255 for float.
|
||||
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
|
||||
// return a * (2^e)
|
||||
return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e)));
|
||||
return pmul(a, preinterpret<Packet>(plogical_shift_left<MantissaBits>(e)));
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user