mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-03 01:04:23 +08:00
Convert bit calculation to constexpr, avoid casts.
This commit is contained in:
parent
baf9a985ec
commit
cb1e8228e9
@ -32,7 +32,7 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
|||||||
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
|
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
|
||||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||||
enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1};
|
static constexpr int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
|
||||||
return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
|
return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,14 +42,13 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
|||||||
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
|
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
|
||||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
|
typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
|
||||||
enum {
|
static constexpr int
|
||||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||||
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
ExponentBits = TotalBits - MantissaBits - 1;
|
||||||
};
|
|
||||||
|
|
||||||
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
|
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
|
||||||
~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
|
~(((ScalarUI(1) << ExponentBits) - ScalarUI(1)) << MantissaBits); // ~0x7f800000
|
||||||
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
|
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
|
||||||
const Packet half = pset1<Packet>(Scalar(0.5));
|
const Packet half = pset1<Packet>(Scalar(0.5));
|
||||||
const Packet zero = pzero(a);
|
const Packet zero = pzero(a);
|
||||||
@ -57,14 +56,14 @@ Packet pfrexp_generic(const Packet& a, Packet& exponent) {
|
|||||||
|
|
||||||
// To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
|
// To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
|
||||||
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
|
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
|
||||||
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
|
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(MantissaBits + 1); // 24
|
||||||
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
|
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
|
||||||
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
|
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
|
||||||
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
|
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
|
||||||
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
|
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
|
||||||
|
|
||||||
// Determine exponent offset: -126 if normal, -126-24 if denormal
|
// Determine exponent offset: -126 if normal, -126-24 if denormal
|
||||||
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
|
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(ExponentBits-1)) - ScalarUI(2)); // -126
|
||||||
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
|
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
|
||||||
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
|
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
|
||||||
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
|
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
|
||||||
@ -73,7 +72,7 @@ Packet pfrexp_generic(const Packet& a, Packet& exponent) {
|
|||||||
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
|
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
|
||||||
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
|
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
|
||||||
// (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
|
// (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
|
||||||
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255
|
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << ExponentBits) - ScalarUI(1)); // 255
|
||||||
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
|
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
|
||||||
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
|
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
|
||||||
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
|
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
|
||||||
@ -110,20 +109,19 @@ Packet pldexp_generic(const Packet& a, const Packet& exponent) {
|
|||||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
||||||
enum {
|
static constexpr int
|
||||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||||
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
ExponentBits = TotalBits - MantissaBits - 1;
|
||||||
};
|
|
||||||
|
|
||||||
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278
|
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<ExponentBits) + ScalarI(MantissaBits - 1))); // 278
|
||||||
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
|
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1)); // 127
|
||||||
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
|
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
|
||||||
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
|
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
|
||||||
Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
|
Packet c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^b
|
||||||
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
|
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
|
||||||
b = psub(psub(psub(e, b), b), b); // e - 3b
|
b = psub(psub(psub(e, b), b), b); // e - 3b
|
||||||
c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
|
c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^(e-3*b)
|
||||||
out = pmul(out, c);
|
out = pmul(out, c);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -142,20 +140,19 @@ struct pldexp_fast_impl {
|
|||||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
||||||
enum {
|
static constexpr int
|
||||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||||
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
ExponentBits = TotalBits - MantissaBits - 1;
|
||||||
};
|
|
||||||
|
|
||||||
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||||
Packet run(const Packet& a, const Packet& exponent) {
|
Packet run(const Packet& a, const Packet& exponent) {
|
||||||
const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127
|
const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(ExponentBits-1)) - ScalarI(1))); // 127
|
||||||
const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255
|
const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<ExponentBits) - ScalarI(1))); // 255
|
||||||
// restrict biased exponent between 0 and 255 for float.
|
// restrict biased exponent between 0 and 255 for float.
|
||||||
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
|
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
|
||||||
// return a * (2^e)
|
// return a * (2^e)
|
||||||
return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e)));
|
return pmul(a, preinterpret<Packet>(plogical_shift_left<MantissaBits>(e)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user